diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..7c73d9f38 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,22 @@ +name: CI + +on: + pull_request: {} + +jobs: + build: + name: Build branch + runs-on: ubuntu-latest + steps: + - name: Checkout source code + uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'maven' + + - name: Build + run: mvn verify diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/publish-snapshot.yml similarity index 98% rename from .github/workflows/continuous-integration.yml rename to .github/workflows/publish-snapshot.yml index e0939f087..5d9b4aa39 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/publish-snapshot.yml @@ -1,4 +1,4 @@ -name: CI/CD build +name: Publish Snapshot on: push: diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..6009a645f --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,119 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our community a +harassment-free experience for everyone, regardless of age, body size, visible or +invisible disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal appearance, +race, religion, or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, diverse, +inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our community +include: + +- Demonstrating empathy and kindness toward other people +- Being respectful of differing opinions, viewpoints, and experiences +- Giving and gracefully accepting constructive feedback +- Accepting responsibility and apologizing to those affected by our mistakes, and + learning from the experience +- Focusing on what is best not just for us as individuals, but for the overall community + +Examples of unacceptable behavior include: + +- The use of sexualized language or imagery, and sexual attention or advances of any kind +- Trolling, insulting or derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or email address, without + their explicit permission +- Other conduct which could reasonably be considered inappropriate in a professional + setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in response to +any behavior that they deem inappropriate, threatening, offensive, or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject comments, +commits, code, wiki edits, issues, and other contributions that are not aligned to this +Code of Conduct, and will communicate reasons for moderation decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when an +individual is officially representing the community in public spaces. Examples of +representing our community include using an official e-mail address, posting via an +official social media account, or acting as an appointed representative at an online or +offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to +the community leaders responsible for enforcement at mcp-coc@anthropic.com. All +complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the reporter +of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining the +consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing clarity +around the nature of the violation and an explanation of why the behavior was +inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of actions. + +**Consequence**: A warning with consequences for continued behavior. No interaction with +the people involved, including unsolicited interaction with those enforcing the Code of +Conduct, for a specified period of time. This includes avoiding interactions in community +spaces as well as external channels like social media. Violating these terms may lead to +a temporary or permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including sustained +inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public communication +with the community for a specified period of time. No public or private interaction with +the people involved, including unsolicited interaction with those enforcing the Code of +Conduct, is allowed during this period. Violating these terms may lead to a permanent +ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community standards, +including sustained inappropriate behavior, harassment of an individual, or aggression +toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, +available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..517f32555 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,94 @@ +# Contributing to Model Context Protocol Java SDK + +Thank you for your interest in contributing to the Model Context Protocol Java SDK! +This document outlines how to contribute to this project. + +## Prerequisites + +The following software is required to work on the codebase: + +- `Java 17` or above +- `Docker` +- `npx` + +## Getting Started + +1. Fork the repository +2. Clone your fork: + +```bash +git clone https://github.com/YOUR-USERNAME/java-sdk.git +cd java-sdk +``` + +3. Build from source: + +```bash +./mvnw clean install -DskipTests # skip the tests +./mvnw test # run tests +``` + +## Reporting Issues + +Please create an issue in the repository if you discover a bug or would like to +propose an enhancement. Bug reports should have a reproducer in the form of a code +sample or a repository attached that the maintainers or contributors can work with to +address the problem. + +## Making Changes + +1. Create a new branch: + +```bash +git checkout -b feature/your-feature-name +``` + +2. Make your changes +3. Validate your changes: + +```bash +./mvnw clean test +``` + +### Change Proposal Guidelines + +#### Principles of MCP + +1. **Simple + Minimal**: It is much easier to add things to the codebase than it is to + remove them. To maintain simplicity, we keep a high bar for adding new concepts and + primitives as each addition requires maintenance and compatibility consideration. +2. **Concrete**: Code changes need to be based on specific usage and implementation + challenges and not on speculative ideas. Most importantly, the SDK is meant to + implement the MCP specification. + +## Submitting Changes + +1. For non-trivial changes, please clarify with the maintainers in an issue whether + you can contribute the change and the desired scope of the change. +2. For trivial changes (for example a couple of lines or documentation changes) there + is no need to open an issue first. +3. Push your changes to your fork. +4. Submit a pull request to the main repository. +5. Follow the pull request template. +6. Wait for review. +7. For any follow-up work, please add new commits instead of force-pushing. This will + allow the reviewer to focus on incremental changes instead of having to restart the + review process. + +## Code of Conduct + +This project follows a Code of Conduct. Please review it in +[CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md). + +## Questions + +If you have questions, please create a discussion in the repository. + +## License + +By contributing, you agree that your contributions will be licensed under the MIT +License. + +## Security + +Please review our [Security Policy](SECURITY.md) for reporting security issues. \ No newline at end of file diff --git a/README.md b/README.md index caa6bf0c0..7bda15006 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,23 @@ # MCP Java SDK -[![Build Status](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/continuous-integration.yml/badge.svg)](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/continuous-integration.yml) +[![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/license/MIT) +[![Build Status](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml/badge.svg)](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml) +[![Maven Central](https://img.shields.io/maven-central/v/io.modelcontextprotocol.sdk/mcp.svg?label=Maven%20Central)](https://central.sonatype.com/artifact/io.modelcontextprotocol.sdk/mcp) +[![Java Version](https://img.shields.io/badge/Java-17%2B-orange)](https://www.oracle.com/java/technologies/javase/jdk17-archive-downloads.html) + A set of projects that provide Java SDK integration for the [Model Context Protocol](https://modelcontextprotocol.org/docs/concepts/architecture). This SDK enables Java applications to interact with AI models and tools through a standardized interface, supporting both synchronous and asynchronous communication patterns. -## 📚 Reference Documentation +## 📚 Reference Documentation #### MCP Java SDK documentation -For comprehensive guides and SDK API documentation, visit the [MCP Java SDK Reference Documentation](https://modelcontextprotocol.io/sdk/java/mcp-overview). +For comprehensive guides and SDK API documentation + +- [Features](https://modelcontextprotocol.io/sdk/java/mcp-overview#features) - Overview the features provided by the Java MCP SDK +- [Architecture](https://modelcontextprotocol.io/sdk/java/mcp-overview#architecture) - Java MCP SDK architecture overview. +- [Java Dependencies / BOM](https://modelcontextprotocol.io/sdk/java/mcp-overview#dependencies) - Java dependencies and BOM. +- [Java MCP Client](https://modelcontextprotocol.io/sdk/java/mcp-client) - Learn how to use the MCP client to interact with MCP servers. +- [Java MCP Server](https://modelcontextprotocol.io/sdk/java/mcp-server) - Learn how to implement and configure a MCP servers. #### Spring AI MCP documentation [Spring AI MCP](https://docs.spring.io/spring-ai/reference/api/mcp/mcp-overview.html) extends the MCP Java SDK with Spring Boot integration, providing both [client](https://docs.spring.io/spring-ai/reference/api/mcp/mcp-client-boot-starter-docs.html) and [server](https://docs.spring.io/spring-ai/reference/api/mcp/mcp-server-boot-starter-docs.html) starters. Bootstrap your AI applications with MCP support using [Spring Initializer](https://start.spring.io). @@ -30,16 +40,14 @@ To run the tests you have to pre-install `Docker` and `npx`. ## Contributing -Contributions are welcome! Please: - -1. Fork the repository -2. Create a feature branch -3. Submit a Pull Request +Contributions are welcome! +Please follow the [Contributing Guidelines](CONTRIBUTING.md). ## Team - Christian Tzolov - Dariusz Jędrzejczyk +- Daniel Garnier-Moiroux ## Links @@ -47,6 +55,133 @@ Contributions are welcome! Please: - [Issue Tracker](https://github.com/modelcontextprotocol/java-sdk/issues) - [CI/CD](https://github.com/modelcontextprotocol/java-sdk/actions) +## Architecture and Design Decisions + +### Introduction + +Building a general-purpose MCP Java SDK requires making technology decisions in areas where the JDK provides limited or no support. The Java ecosystem is powerful but fragmented: multiple valid approaches exist, each with strong communities. +Our goal is not to prescribe "the one true way," but to provide a reference implementation of the MCP specification that is: + +* **Pragmatic** – makes developers productive quickly +* **Interoperable** – aligns with widely used libraries and practices +* **Pluggable** – allows alternatives where projects prefer different stacks +* **Grounded in team familiarity** – we chose technologies the team can be productive with today, while remaining open to community contributions that broaden the SDK + +### Key Choices and Considerations + +The SDK had to make decisions in the following areas: + +1. **JSON serialization** – mapping between JSON and Java types + +2. **Programming model** – supporting asynchronous processing, cancellation, and streaming while staying simple for blocking use cases + +3. **Observability** – logging and enabling integration with metrics/tracing + +4. **Remote clients and servers** – supporting both consuming MCP servers (client transport) and exposing MCP endpoints (server transport with authorization) + +The following sections explain what we chose, why it made sense, and how the choices align with the SDK's goals. + +### 1. JSON Serialization + +* **SDK Choice**: Jackson for JSON serialization and deserialization, behind an SDK abstraction (`mcp-json`) + +* **Why**: Jackson is widely adopted across the Java ecosystem, provides strong performance and a mature annotation model, and is familiar to the SDK team and many potential contributors. + +* **How we expose it**: Public APIs use a zero-dependency abstraction (`mcp-json`). Jackson is shipped as the default implementation (`mcp-jackson2`), but alternatives can be plugged in. + +* **How it fits the SDK**: This offers a pragmatic default while keeping flexibility for projects that prefer different JSON libraries. + +### 2. Programming Model + +* **SDK Choice**: Reactive Streams for public APIs, with Project Reactor as the internal implementation and a synchronous facade for blocking use cases + +* **Why**: MCP builds on JSON-RPC's asynchronous nature and defines a bidirectional protocol on top of it, enabling asynchronous and streaming interactions. MCP explicitly supports: + + * Multiple in-flight requests and responses + * Notifications that do not expect a reply + * STDIO transports for inter-process communication using pipes + * Streaming transports such as Server-Sent Events and Streamable HTTP + + These requirements call for a programming model more powerful than single-result futures like `CompletableFuture`. + + * **Reactive Streams: the Community Standard** + + Reactive Streams is a small Java specification that standardizes asynchronous stream processing with backpressure. It defines four minimal interfaces (Publisher, Subscriber, Subscription, and Processor). These interfaces are widely recognized as the standard contract for async, non-blocking pipelines in Java. + + * **Reactive Streams Implementation** + + The SDK uses Project Reactor as its implementation of the Reactive Streams specification. Reactor is mature, widely adopted, provides rich operators, and integrates well with observability through context propagation. Team familiarity also allowed us to deliver a solid foundation quickly. + We plan to convert the public API to only expose Reactive Streams interfaces. By defining the public API in terms of Reactive Streams interfaces and using Reactor internally, the SDK stays standards-based while benefiting from a practical, production-ready implementation. + + * **Synchronous Facade in the SDK** + + Not all MCP use cases require streaming pipelines. Many scenarios are as simple as "send a request and block until I get the result." + To support this, the SDK provides a synchronous facade layered on top of the reactive core. Developers can stay in a blocking model when it's enough, while still having access to asynchronous streaming when needed. + +* **How it fits the SDK**: This design balances scalability, approachability, and future evolution such as Virtual Threads and Structured Concurrency in upcoming JDKs. + +### 3. Observability + +* **SDK Choice**: SLF4J for logging; Reactor Context for observability propagation + +* **Why**: SLF4J is the de facto logging facade in Java, with broad compatibility. Reactor Context enables propagation of observability data such as correlation IDs and tracing state across async boundaries. This ensures interoperability with modern observability frameworks. + +* **How we expose it**: Public APIs log through SLF4J only, with no backend included. Observability metadata flows through Reactor pipelines. The SDK itself does not ship metrics or tracing implementations. + +* **How it fits the SDK**: This provides reliable logging by default and seamless integration with Micrometer, OpenTelemetry, or similar systems for metrics and tracing. + +### 4. Remote MCP Clients and Servers + +MCP supports both clients (applications consuming MCP servers) and servers (applications exposing MCP endpoints). The SDK provides support for both sides. + +#### Client Transport in the SDK + +* **SDK Choice**: JDK HttpClient (Java 11+) as the default client, with optional Spring WebClient support + +* **Why**: The JDK HttpClient is built-in, portable, and supports streaming responses. This keeps the default lightweight with no extra dependencies. Spring WebClient support is available for Spring-based projects. + +* **How we expose it**: MCP Client APIs are transport-agnostic. The core module ships with JDK HttpClient transport. A Spring module provides WebClient integration. + +* **How it fits the SDK**: This ensures all applications can talk to MCP servers out of the box, while allowing richer integration in Spring and other environments. + +#### Server Transport in the SDK + +* **SDK Choice**: Jakarta Servlet implementation in core, with optional Spring WebFlux and Spring WebMVC providers + +* **Why**: Servlet is the most widely deployed Java server API. WebFlux and WebMVC cover a significant part of the Spring community. Together these provide reach across blocking and non-blocking models. + +* **How we expose it**: Server APIs are transport-agnostic. Core includes Servlet support. Spring modules extend support for WebFlux and WebMVC. + +* **How it fits the SDK**: This allows developers to expose MCP servers in the most common Java environments today, while enabling other transport implementations such as Netty, Vert.x, or Helidon. + +#### Authorization in the SDK + +* **SDK Choice**: Pluggable authorization hooks for MCP servers; no built-in implementation + +* **Why**: MCP servers must restrict access to authenticated and authorized clients. Authorization needs differ across environments such as Spring Security, MicroProfile JWT, or custom solutions. Providing hooks avoids lock-in and leverages proven libraries. + +* **How we expose it**: Authorization is integrated into the server transport layer. The SDK does not include its own authorization system. + +* **How it fits the SDK**: This keeps server-side security ecosystem-neutral, while ensuring applications can plug in their preferred authorization strategy. + +### Project Structure of the SDK + +The SDK is organized into modules to separate concerns and allow adopters to bring in only what they need: +* `mcp-bom` – Dependency versions +* `mcp-core` – Reference implementation (STDIO, JDK HttpClient, Servlet) +* `mcp-json` – JSON abstraction +* `mcp-jackson2` – Jackson implementation of JSON binding +* `mcp` – Convenience bundle (core + Jackson) +* `mcp-test` – Shared testing utilities +* `mcp-spring` – Spring integrations (WebClient, WebFlux, WebMVC) + +For example, a minimal adopter may depend only on `mcp` (core + Jackson), while a Spring-based application can use `mcp-spring` for deeper framework integration. + +### Future Directions + +The SDK is designed to evolve with the Java ecosystem. Areas we are actively watching include: +Concurrency in the JDK – Virtual Threads and Structured Concurrency may simplify the synchronous API story + ## License This project is licensed under the [MIT License](LICENSE). diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..74e9880fd --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,21 @@ +# Security Policy + +Thank you for helping us keep the SDKs and systems they interact with secure. + +## Reporting Security Issues + +This SDK is maintained by [Anthropic](https://www.anthropic.com/) as part of the Model +Context Protocol project. + +The security of our systems and user data is Anthropic’s top priority. We appreciate the +work of security researchers acting in good faith in identifying and reporting potential +vulnerabilities. + +Our security program is managed on HackerOne and we ask that any validated vulnerability +in this functionality be reported through their +[submission form](https://hackerone.com/anthropic-vdp/reports/new?type=team&report_type=vulnerability). + +## Vulnerability Disclosure Program + +Our Vulnerability Program Guidelines are defined on our +[HackerOne program page](https://hackerone.com/anthropic-vdp). \ No newline at end of file diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 3b2ad42c8..447c9e0bd 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.8.0-SNAPSHOT + 0.18.0-SNAPSHOT mcp-bom @@ -27,12 +27,33 @@ + + io.modelcontextprotocol.sdk + mcp-core + ${project.version} + + + io.modelcontextprotocol.sdk mcp ${project.version} + + + io.modelcontextprotocol.sdk + mcp-json + ${project.version} + + + + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + ${project.version} + + io.modelcontextprotocol.sdk diff --git a/mcp-core/pom.xml b/mcp-core/pom.xml new file mode 100644 index 000000000..9e23ffd79 --- /dev/null +++ b/mcp-core/pom.xml @@ -0,0 +1,235 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + mcp-parent + 0.18.0-SNAPSHOT + + mcp-core + jar + Java MCP SDK Core + Core classes of the Java SDK implementation of the Model Context Protocol, enabling seamless integration with language models and AI tools + https://github.com/modelcontextprotocol/java-sdk + + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + + + biz.aQute.bnd + bnd-maven-plugin + ${bnd-maven-plugin.version} + + + bnd-process + + bnd-process + + + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + ${project.build.outputDirectory}/META-INF/MANIFEST.MF + + + + + + + + + io.modelcontextprotocol.sdk + mcp-json + 0.18.0-SNAPSHOT + + + + org.slf4j + slf4j-api + ${slf4j-api.version} + + + + com.fasterxml.jackson.core + jackson-annotations + ${jackson.version} + + + + io.projectreactor + reactor-core + + + + + + jakarta.servlet + jakarta.servlet-api + ${jakarta.servlet.version} + provided + + + + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 0.18.0-SNAPSHOT + test + + + + org.springframework + spring-webmvc + ${springframework.version} + test + + + + + io.projectreactor.netty + reactor-netty-http + test + + + + + org.springframework + spring-context + ${springframework.version} + test + + + + org.springframework + spring-test + ${springframework.version} + test + + + + org.assertj + assertj-core + ${assert4j.version} + test + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + + + io.projectreactor + reactor-test + test + + + org.testcontainers + junit-jupiter + ${testcontainers.version} + test + + + + org.awaitility + awaitility + ${awaitility.version} + test + + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} + test + + + + + org.apache.tomcat.embed + tomcat-embed-core + ${tomcat.version} + test + + + org.apache.tomcat.embed + tomcat-embed-websocket + ${tomcat.version} + test + + + + org.testcontainers + toxiproxy + ${toxiproxy.version} + test + + + + + com.google.code.gson + gson + 2.10.1 + test + + + + + diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java new file mode 100644 index 000000000..07d86f40e --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java @@ -0,0 +1,358 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.context.ContextView; + +/** + * Handles the protocol initialization phase between client and server + * + *

+ * The initialization phase MUST be the first interaction between client and server. + * During this phase, the client and server perform the following operations: + *

    + *
  • Establish protocol version compatibility
  • + *
  • Exchange and negotiate capabilities
  • + *
  • Share implementation details
  • + *
+ * + * Client Initialization Process + *

+ * The client MUST initiate this phase by sending an initialize request containing: + *

    + *
  • Protocol version supported
  • + *
  • Client capabilities
  • + *
  • Client implementation information
  • + *
+ * + *

+ * After successful initialization, the client MUST send an initialized notification to + * indicate it is ready to begin normal operations. + * + * Server Response + *

+ * The server MUST respond with its own capabilities and information. + * + * Protocol Version Negotiation + *

+ * In the initialize request, the client MUST send a protocol version it supports. This + * SHOULD be the latest version supported by the client. + * + *

+ * If the server supports the requested protocol version, it MUST respond with the same + * version. Otherwise, the server MUST respond with another protocol version it supports. + * This SHOULD be the latest version supported by the server. + * + *

+ * If the client does not support the version in the server's response, it SHOULD + * disconnect. + * + * Request Restrictions + *

+ * Important: The following restrictions apply during initialization: + *

    + *
  • The client SHOULD NOT send requests other than pings before the server has + * responded to the initialize request
  • + *
  • The server SHOULD NOT send requests other than pings and logging before receiving + * the initialized notification
  • + *
+ */ +class LifecycleInitializer { + + private static final Logger logger = LoggerFactory.getLogger(LifecycleInitializer.class); + + /** + * The MCP session supplier that manages bidirectional JSON-RPC communication between + * clients and servers. + */ + private final Function sessionSupplier; + + private final McpSchema.ClientCapabilities clientCapabilities; + + private final McpSchema.Implementation clientInfo; + + private List protocolVersions; + + private final AtomicReference initializationRef = new AtomicReference<>(); + + /** + * The max timeout to await for the client-server connection to be initialized. + */ + private final Duration initializationTimeout; + + /** + * Post-initialization hook to perform additional operations after every successful + * initialization. + */ + private final Function> postInitializationHook; + + public LifecycleInitializer(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + List protocolVersions, Duration initializationTimeout, + Function sessionSupplier, + Function> postInitializationHook) { + + Assert.notNull(sessionSupplier, "Session supplier must not be null"); + Assert.notNull(clientCapabilities, "Client capabilities must not be null"); + Assert.notNull(clientInfo, "Client info must not be null"); + Assert.notEmpty(protocolVersions, "Protocol versions must not be empty"); + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + Assert.notNull(postInitializationHook, "Post-initialization hook must not be null"); + + this.sessionSupplier = sessionSupplier; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + this.protocolVersions = Collections.unmodifiableList(new ArrayList<>(protocolVersions)); + this.initializationTimeout = initializationTimeout; + this.postInitializationHook = postInitializationHook; + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + + /** + * Represents the initialization state of the MCP client. + */ + interface Initialization { + + /** + * Returns the MCP client session that is used to communicate with the server. + * This session is established during the initialization process and is used for + * sending requests and notifications. + * @return The MCP client session + */ + McpClientSession mcpSession(); + + /** + * Returns the result of the MCP initialization process. This result contains + * information about the protocol version, capabilities, server info, and + * instructions provided by the server during the initialization phase. + * @return The result of the MCP initialization process + */ + McpSchema.InitializeResult initializeResult(); + + } + + private static class DefaultInitialization implements Initialization { + + /** + * A sink that emits the result of the MCP initialization process. It allows + * subscribers to wait for the initialization to complete. + */ + private final Sinks.One initSink; + + /** + * Holds the result of the MCP initialization process. It is used to cache the + * result for future requests. + */ + private final AtomicReference result; + + /** + * Holds the MCP client session that is used to communicate with the server. It is + * set during the initialization process and used for sending requests and + * notifications. + */ + private final AtomicReference mcpClientSession; + + private DefaultInitialization() { + this.initSink = Sinks.one(); + this.result = new AtomicReference<>(); + this.mcpClientSession = new AtomicReference<>(); + } + + // --------------------------------------------------- + // Public access for mcpSession and initializeResult because they are + // used in by the McpAsyncClient. + // ---------------------------------------------------- + public McpClientSession mcpSession() { + return this.mcpClientSession.get(); + } + + public McpSchema.InitializeResult initializeResult() { + return this.result.get(); + } + + // --------------------------------------------------- + // Private accessors used internally by the LifecycleInitializer to set the MCP + // client session and complete the initialization process. + // --------------------------------------------------- + private void setMcpClientSession(McpClientSession mcpClientSession) { + this.mcpClientSession.set(mcpClientSession); + } + + private Mono await() { + return this.initSink.asMono(); + } + + private void complete(McpSchema.InitializeResult initializeResult) { + // inform all the subscribers waiting for the initialization + this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST); + } + + private void cacheResult(McpSchema.InitializeResult initializeResult) { + // first ensure the result is cached + this.result.set(initializeResult); + } + + private void error(Throwable t) { + this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); + } + + private void close() { + this.mcpSession().close(); + } + + private Mono closeGracefully() { + return this.mcpSession().closeGracefully(); + } + + } + + public boolean isInitialized() { + return this.currentInitializationResult() != null; + } + + public McpSchema.InitializeResult currentInitializationResult() { + DefaultInitialization current = this.initializationRef.get(); + McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; + return initializeResult; + } + + /** + * Hook to handle exceptions that occur during the MCP transport session. + *

+ * If the exception is a {@link McpTransportSessionNotFoundException}, it indicates + * that the session was not found, and we should re-initialize the client. + *

+ * @param t The exception to handle + */ + public void handleException(Throwable t) { + logger.warn("Handling exception", t); + if (t instanceof McpTransportSessionNotFoundException) { + DefaultInitialization previous = this.initializationRef.getAndSet(null); + if (previous != null) { + previous.close(); + } + // Providing an empty operation since we are only interested in triggering + // the implicit initialization step. + this.withInitialization("re-initializing", result -> Mono.empty()).subscribe(); + } + } + + /** + * Utility method to ensure the initialization is established before executing an + * operation. + * @param The type of the result Mono + * @param actionName The action to perform when the client is initialized + * @param operation The operation to execute when the client is initialized + * @return A Mono that completes with the result of the operation + */ + public Mono withInitialization(String actionName, Function> operation) { + return Mono.deferContextual(ctx -> { + DefaultInitialization newInit = new DefaultInitialization(); + DefaultInitialization previous = this.initializationRef.compareAndExchange(null, newInit); + + boolean needsToInitialize = previous == null; + logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization"); + + Mono initializationJob = needsToInitialize + ? this.doInitialize(newInit, this.postInitializationHook, ctx) : previous.await(); + + return initializationJob.map(initializeResult -> this.initializationRef.get()) + .timeout(this.initializationTimeout) + .onErrorResume(ex -> { + this.initializationRef.compareAndSet(newInit, null); + return Mono.error(new RuntimeException("Client failed to initialize " + actionName, ex)); + }) + .flatMap(res -> operation.apply(res) + .contextWrite(c -> c.put(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + res.initializeResult().protocolVersion()))); + }); + } + + private Mono doInitialize(DefaultInitialization initialization, + Function> postInitOperation, ContextView ctx) { + + initialization.setMcpClientSession(this.sessionSupplier.apply(ctx)); + + McpClientSession mcpClientSession = initialization.mcpSession(); + + String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(latestVersion, + this.clientCapabilities, this.clientInfo); + + Mono result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE, + initializeRequest, McpAsyncClient.INITIALIZE_RESULT_TYPE_REF); + + return result.flatMap(initializeResult -> { + logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", + initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(), + initializeResult.instructions()); + + if (!this.protocolVersions.contains(initializeResult.protocolVersion())) { + return Mono.error(McpError.builder(-32602) + .message("Unsupported protocol version") + .data("Unsupported protocol version from the server: " + initializeResult.protocolVersion()) + .build()); + } + + return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) + .contextWrite( + c -> c.put(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, initializeResult.protocolVersion())) + .thenReturn(initializeResult); + }).flatMap(initializeResult -> { + initialization.cacheResult(initializeResult); + return postInitOperation.apply(initialization).thenReturn(initializeResult); + }).doOnNext(initialization::complete).onErrorResume(ex -> { + initialization.error(ex); + return Mono.error(ex); + }); + } + + /** + * Closes the current initialization if it exists. + */ + public void close() { + DefaultInitialization current = this.initializationRef.getAndSet(null); + if (current != null) { + current.close(); + } + } + + /** + * Gracefully closes the current initialization if it exists. + * @return A Mono that completes when the connection is closed + */ + public Mono closeGracefully() { + return Mono.defer(() -> { + DefaultInitialization current = this.initializationRef.getAndSet(null); + Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); + return sessionClose; + }); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java similarity index 53% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index b301aa93a..e6a09cd08 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -1,28 +1,33 @@ /* * Copyright 2024-2024 the original author or authors. */ + package io.modelcontextprotocol.client; import java.time.Duration; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import io.modelcontextprotocol.spec.ClientMcpTransport; -import io.modelcontextprotocol.spec.DefaultMcpSession; -import io.modelcontextprotocol.spec.DefaultMcpSession.NotificationHandler; -import io.modelcontextprotocol.spec.DefaultMcpSession.RequestHandler; -import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; +import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; @@ -30,14 +35,12 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; /** * The Model Context Protocol (MCP) client implementation that provides asynchronous @@ -71,32 +74,39 @@ * * @author Dariusz Jędrzejczyk * @author Christian Tzolov + * @author Jihoon Kim + * @author Anurag Pant * @see McpClient * @see McpSchema - * @see DefaultMcpSession + * @see McpClientSession + * @see McpClientTransport */ public class McpAsyncClient { private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class); - private static TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { + private static final TypeRef VOID_TYPE_REFERENCE = new TypeRef<>() { }; - protected final Sinks.One initializedSink = Sinks.one(); + public static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { + }; - private AtomicBoolean initialized = new AtomicBoolean(false); + public static final TypeRef PAGINATED_REQUEST_TYPE_REF = new TypeRef<>() { + }; - /** - * The max timeout to await for the client-server connection to be initialized. - * Usually x2 the request timeout. // TODO should we make it configurable? - */ - private final Duration initializationTimeout; + public static final TypeRef INITIALIZE_RESULT_TYPE_REF = new TypeRef<>() { + }; - /** - * The MCP session implementation that manages bidirectional JSON-RPC communication - * between clients and servers. - */ - private final DefaultMcpSession mcpSession; + public static final TypeRef CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeRef<>() { + }; + + public static final TypeRef LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeRef<>() { + }; + + public static final TypeRef PROGRESS_NOTIFICATION_TYPE_REF = new TypeRef<>() { + }; + + public static final String NEGOTIATED_PROTOCOL_VERSION = "io.modelcontextprotocol.client.negotiated-protocol-version"; /** * Client capabilities. @@ -108,16 +118,6 @@ public class McpAsyncClient { */ private final McpSchema.Implementation clientInfo; - /** - * Server capabilities. - */ - private McpSchema.ServerCapabilities serverCapabilities; - - /** - * Server implementation information. - */ - private McpSchema.Implementation serverInfo; - /** * Roots define the boundaries of where servers can operate within the filesystem, * allowing them to understand which directories and files they have access to. @@ -136,37 +136,74 @@ public class McpAsyncClient { */ private Function> samplingHandler; + /** + * MCP provides a standardized way for servers to request additional information from + * users through the client during interactions. This flow allows clients to maintain + * control over user interactions and data sharing while enabling servers to gather + * necessary information dynamically. Servers can request structured data from users + * with optional JSON schemas to validate responses. + */ + private Function> elicitationHandler; + /** * Client transport implementation. */ - private final McpTransport transport; + private final McpClientTransport transport; + + /** + * The lifecycle initializer that manages the client-server connection initialization. + */ + private final LifecycleInitializer initializer; /** - * Supported protocol versions. + * JSON schema validator to use for validating tool responses against output schemas. */ - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + private final JsonSchemaValidator jsonSchemaValidator; + + /** + * Cached tool output schemas. + */ + private final ConcurrentHashMap> toolsOutputSchemaCache; + + /** + * Whether to enable automatic schema caching during callTool operations. + */ + private final boolean enableCallToolSchemaCaching; /** * Create a new McpAsyncClient with the given transport and session request-response * timeout. * @param transport the transport to use. * @param requestTimeout the session request-response timeout. - * @param features the MCP Client supported features. + * @param initializationTimeout the max timeout to await for the client-server + * @param jsonSchemaValidator the JSON schema validator to use for validating tool + * @param features the MCP Client supported features. responses against output + * schemas. */ - McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, McpClientFeatures.Async features) { + McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout, + JsonSchemaValidator jsonSchemaValidator, McpClientFeatures.Async features) { Assert.notNull(transport, "Transport must not be null"); Assert.notNull(requestTimeout, "Request timeout must not be null"); + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); this.clientInfo = features.clientInfo(); this.clientCapabilities = features.clientCapabilities(); this.transport = transport; this.roots = new ConcurrentHashMap<>(features.roots()); - this.initializationTimeout = requestTimeout.multipliedBy(2); + this.jsonSchemaValidator = jsonSchemaValidator; + this.toolsOutputSchemaCache = new ConcurrentHashMap<>(); + this.enableCallToolSchemaCaching = features.enableCallToolSchemaCaching(); // Request Handlers Map> requestHandlers = new HashMap<>(); + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, params -> { + logger.debug("Received ping: {}", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)); + return Mono.just(Map.of()); + }); + // Roots List Request Handler if (this.clientCapabilities.roots() != null) { requestHandlers.put(McpSchema.METHOD_ROOTS_LIST, rootsListRequestHandler()); @@ -175,12 +212,23 @@ public class McpAsyncClient { // Sampling Handler if (this.clientCapabilities.sampling() != null) { if (features.samplingHandler() == null) { - throw new McpError("Sampling handler must not be null when client capabilities include sampling"); + throw new IllegalArgumentException( + "Sampling handler must not be null when client capabilities include sampling"); } this.samplingHandler = features.samplingHandler(); requestHandlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler()); } + // Elicitation Handler + if (this.clientCapabilities.elicitation() != null) { + if (features.elicitationHandler() == null) { + throw new IllegalArgumentException( + "Elicitation handler must not be null when client capabilities include elicitation"); + } + this.elicitationHandler = features.elicitationHandler(); + requestHandlers.put(McpSchema.METHOD_ELICITATION_CREATE, elicitationCreateHandler()); + } + // Notification Handlers Map notificationHandlers = new HashMap<>(); @@ -207,6 +255,18 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, asyncResourcesChangeNotificationHandler(resourcesChangeConsumersFinal)); + // Resources Update Notification + List, Mono>> resourcesUpdateConsumersFinal = new ArrayList<>(); + resourcesUpdateConsumersFinal + .add((notification) -> Mono.fromRunnable(() -> logger.debug("Resources updated: {}", notification))); + + if (!Utils.isEmpty(features.resourcesUpdateConsumers())) { + resourcesUpdateConsumersFinal.addAll(features.resourcesUpdateConsumers()); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_RESOURCES_UPDATED, + asyncResourcesUpdatedNotificationHandler(resourcesUpdateConsumersFinal)); + // Prompts Change Notification List, Mono>> promptsChangeConsumersFinal = new ArrayList<>(); promptsChangeConsumersFinal @@ -226,8 +286,49 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); - this.mcpSession = new DefaultMcpSession(requestTimeout, transport, requestHandlers, notificationHandlers); + // Utility Progress Notification + List>> progressConsumersFinal = new ArrayList<>(); + progressConsumersFinal + .add((notification) -> Mono.fromRunnable(() -> logger.debug("Progress: {}", notification))); + if (!Utils.isEmpty(features.progressConsumers())) { + progressConsumersFinal.addAll(features.progressConsumers()); + } + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS, + asyncProgressNotificationHandler(progressConsumersFinal)); + + Function> postInitializationHook = init -> { + + if (init.initializeResult().capabilities().tools() == null || !enableCallToolSchemaCaching) { + return Mono.empty(); + } + + return this.listToolsInternal(init, McpSchema.FIRST_PAGE).doOnNext(listToolsResult -> { + listToolsResult.tools() + .forEach(tool -> logger.debug("Tool {} schema: {}", tool.name(), tool.outputSchema())); + if (enableCallToolSchemaCaching && listToolsResult.tools() != null) { + // Cache tools output schema + listToolsResult.tools() + .stream() + .filter(tool -> tool.outputSchema() != null) + .forEach(tool -> this.toolsOutputSchemaCache.put(tool.name(), tool.outputSchema())); + } + }).then(); + }; + + this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo, transport.protocolVersions(), + initializationTimeout, ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, + notificationHandlers, con -> con.contextWrite(ctx)), + postInitializationHook); + + this.transport.setExceptionHandler(this.initializer::handleException); + } + /** + * Get the current initialization result. + * @return the initialization result. + */ + public McpSchema.InitializeResult getCurrentInitializationResult() { + return this.initializer.currentInitializationResult(); } /** @@ -235,7 +336,18 @@ public class McpAsyncClient { * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; + McpSchema.InitializeResult initializeResult = this.initializer.currentInitializationResult(); + return initializeResult != null ? initializeResult.capabilities() : null; + } + + /** + * Get the server instructions that provide guidance to the client on how to interact + * with this server. + * @return The server instructions + */ + public String getServerInstructions() { + McpSchema.InitializeResult initializeResult = this.initializer.currentInitializationResult(); + return initializeResult != null ? initializeResult.instructions() : null; } /** @@ -243,7 +355,8 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - return this.serverInfo; + McpSchema.InitializeResult initializeResult = this.initializer.currentInitializationResult(); + return initializeResult != null ? initializeResult.serverInfo() : null; } /** @@ -251,7 +364,7 @@ public McpSchema.Implementation getServerInfo() { * @return true if the client-server connection is initialized */ public boolean isInitialized() { - return this.initialized.get(); + return this.initializer.isInitialized(); } /** @@ -274,7 +387,8 @@ public McpSchema.Implementation getClientInfo() { * Closes the client connection immediately. */ public void close() { - this.mcpSession.close(); + this.initializer.close(); + this.transport.close(); } /** @@ -282,14 +396,20 @@ public void close() { * @return A Mono that completes when the connection is closed */ public Mono closeGracefully() { - return this.mcpSession.closeGracefully(); + return Mono.defer(() -> { + return this.initializer.closeGracefully().then(transport.closeGracefully()); + }); } // -------------------------- // Initialization // -------------------------- + /** - * The initialization phase MUST be the first interaction between client and server. + * The initialization phase should be the first interaction between client and server. + * The client will ensure it happens in case it has not been explicitly called and in + * case of transport session invalidation. + *

* During this phase, the client and server: *

    *
  • Establish protocol version compatibility
  • @@ -300,69 +420,23 @@ public Mono closeGracefully() { * The client MUST initiate this phase by sending an initialize request containing: * The protocol version the client supports, client's capabilities and clients * implementation information. - *

    + *

    * The server MUST respond with its own capabilities and information. - *

    + *

    * After successful initialization, the client MUST send an initialized notification * to indicate it is ready to begin normal operations. * @return the initialize result. * @see MCP * Initialization Spec + *

    */ public Mono initialize() { - - String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off - latestVersion, - this.clientCapabilities, - this.clientInfo); // @formatter:on - - Mono result = this.mcpSession.sendRequest(McpSchema.METHOD_INITIALIZE, - initializeRequest, new TypeReference() { - }); - - return result.flatMap(initializeResult -> { - - this.serverCapabilities = initializeResult.capabilities(); - this.serverInfo = initializeResult.serverInfo(); - - logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", - initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(), - initializeResult.instructions()); - - if (!this.protocolVersions.contains(initializeResult.protocolVersion())) { - return Mono.error(new McpError( - "Unsupported protocol version from the server: " + initializeResult.protocolVersion())); - } - - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null).doOnSuccess(v -> { - this.initialized.set(true); - this.initializedSink.tryEmitValue(initializeResult); - }).thenReturn(initializeResult); - }); - } - - /** - * Utility method to handle the common pattern of checking initialization before - * executing an operation. - * @param The type of the result Mono - * @param actionName The action to perform if the client is initialized - * @param operation The operation to execute if the client is initialized - * @return A Mono that completes with the result of the operation - */ - private Mono withInitializationCheck(String actionName, - Function> operation) { - return this.initializedSink.asMono() - .timeout(this.initializationTimeout) - .onErrorResume(TimeoutException.class, - ex -> Mono.error(new McpError("Client must be initialized before " + actionName))) - .flatMap(operation); + return this.initializer.withInitialization("by explicit API call", init -> Mono.just(init.initializeResult())); } // -------------------------- - // Basic Utilites + // Basic Utilities // -------------------------- /** @@ -370,14 +444,14 @@ private Mono withInitializationCheck(String actionName, * @return A Mono that completes with the server's ping response */ public Mono ping() { - return this.withInitializationCheck("pinging the server", initializedResult -> this.mcpSession - .sendRequest(McpSchema.METHOD_PING, null, new TypeReference() { - })); + return this.initializer.withInitialization("pinging the server", + init -> init.mcpSession().sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF)); } // -------------------------- // Roots // -------------------------- + /** * Adds a new root to the client's root list. * @param root The root to add. @@ -386,15 +460,15 @@ public Mono ping() { public Mono addRoot(Root root) { if (root == null) { - return Mono.error(new McpError("Root must not be null")); + return Mono.error(new IllegalArgumentException("Root must not be null")); } if (this.clientCapabilities.roots() == null) { - return Mono.error(new McpError("Client must be configured with roots capabilities")); + return Mono.error(new IllegalStateException("Client must be configured with roots capabilities")); } if (this.roots.containsKey(root.uri())) { - return Mono.error(new McpError("Root with uri '" + root.uri() + "' already exists")); + return Mono.error(new IllegalStateException("Root with uri '" + root.uri() + "' already exists")); } this.roots.put(root.uri(), root); @@ -420,11 +494,11 @@ public Mono addRoot(Root root) { public Mono removeRoot(String rootUri) { if (rootUri == null) { - return Mono.error(new McpError("Root uri must not be null")); + return Mono.error(new IllegalArgumentException("Root uri must not be null")); } if (this.clientCapabilities.roots() == null) { - return Mono.error(new McpError("Client must be configured with roots capabilities")); + return Mono.error(new IllegalStateException("Client must be configured with roots capabilities")); } Root removed = this.roots.remove(rootUri); @@ -442,7 +516,7 @@ public Mono removeRoot(String rootUri) { } return Mono.empty(); } - return Mono.error(new McpError("Root with uri '" + rootUri + "' not found")); + return Mono.error(new IllegalStateException("Root with uri '" + rootUri + "' not found")); } /** @@ -452,16 +526,14 @@ public Mono removeRoot(String rootUri) { * @return A Mono that completes when the notification is sent. */ public Mono rootsListChangedNotification() { - return this.withInitializationCheck("sending roots list changed notification", - initResult -> this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); + return this.initializer.withInitialization("sending roots list changed notification", + init -> init.mcpSession().sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); } private RequestHandler rootsListRequestHandler() { return params -> { @SuppressWarnings("unused") - McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, - new TypeReference() { - }); + McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, PAGINATED_REQUEST_TYPE_REF); List roots = this.roots.values().stream().toList(); @@ -474,21 +546,31 @@ private RequestHandler rootsListRequestHandler() { // -------------------------- private RequestHandler samplingCreateMessageHandler() { return params -> { - McpSchema.CreateMessageRequest request = transport.unmarshalFrom(params, - new TypeReference() { - }); + McpSchema.CreateMessageRequest request = transport.unmarshalFrom(params, CREATE_MESSAGE_REQUEST_TYPE_REF); return this.samplingHandler.apply(request); }; } + // -------------------------- + // Elicitation + // -------------------------- + private RequestHandler elicitationCreateHandler() { + return params -> { + ElicitRequest request = transport.unmarshalFrom(params, new TypeRef<>() { + }); + + return this.elicitationHandler.apply(request); + }; + } + // -------------------------- // Tools // -------------------------- - private static final TypeReference CALL_TOOL_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef CALL_TOOL_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference LIST_TOOLS_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_TOOLS_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -503,20 +585,57 @@ private RequestHandler samplingCreateMessageHandler() { * @see #listTools() */ public Mono callTool(McpSchema.CallToolRequest callToolRequest) { - return this.withInitializationCheck("calling tools", initializedResult -> { - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server does not provide tools capability")); + return this.initializer.withInitialization("calling tool", init -> { + if (init.initializeResult().capabilities().tools() == null) { + return Mono.error(new IllegalStateException("Server does not provide tools capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); + + return init.mcpSession() + .sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF) + .flatMap(result -> Mono.just(validateToolResult(callToolRequest.name(), result))); }); } + private McpSchema.CallToolResult validateToolResult(String toolName, McpSchema.CallToolResult result) { + + if (!this.enableCallToolSchemaCaching || result == null || result.isError() == Boolean.TRUE) { + // if tool schema caching is disabled or tool call resulted in an error - skip + // validation and return the result as it is + return result; + } + + Map optOutputSchema = this.toolsOutputSchemaCache.get(toolName); + + if (optOutputSchema == null) { + logger.warn( + "Calling a tool with no outputSchema is not expected to return result with structured content, but got: {}", + result.structuredContent()); + return result; + } + + // Validate the tool output against the cached output schema + var validation = this.jsonSchemaValidator.validate(optOutputSchema, result.structuredContent()); + + if (!validation.valid()) { + logger.warn("Tool call result validation failed: {}", validation.errorMessage()); + throw new IllegalArgumentException("Tool call result validation failed: " + validation.errorMessage()); + } + + return result; + } + /** * Retrieves the list of all tools provided by the server. - * @return A Mono that emits the list of tools result. + * @return A Mono that emits the list of all tools result */ public Mono listTools() { - return this.listTools(null); + return this.listTools(McpSchema.FIRST_PAGE).expand(result -> { + String next = result.nextCursor(); + return (next != null && !next.isEmpty()) ? this.listTools(next) : Mono.empty(); + }).reduce(new McpSchema.ListToolsResult(new ArrayList<>(), null), (allToolsResult, result) -> { + allToolsResult.tools().addAll(result.tools()); + return allToolsResult; + }).map(result -> new McpSchema.ListToolsResult(Collections.unmodifiableList(result.tools()), null)); } /** @@ -525,13 +644,26 @@ public Mono listTools() { * @return A Mono that emits the list of tools result */ public Mono listTools(String cursor) { - return this.withInitializationCheck("listing tools", initializedResult -> { - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server does not provide tools capability")); - } - return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_TOOLS_RESULT_TYPE_REF); - }); + return this.initializer.withInitialization("listing tools", init -> this.listToolsInternal(init, cursor)); + } + + private Mono listToolsInternal(Initialization init, String cursor) { + + if (init.initializeResult().capabilities().tools() == null) { + return Mono.error(new IllegalStateException("Server does not provide tools capability")); + } + return init.mcpSession() + .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_TOOLS_RESULT_TYPE_REF) + .doOnNext(result -> { + if (this.enableCallToolSchemaCaching && result.tools() != null) { + // Cache tools output schema + result.tools() + .stream() + .filter(tool -> tool.outputSchema() != null) + .forEach(tool -> this.toolsOutputSchemaCache.put(tool.name(), tool.outputSchema())); + } + }); } private NotificationHandler asyncToolsChangeNotificationHandler( @@ -551,25 +683,31 @@ private NotificationHandler asyncToolsChangeNotificationHandler( // Resources // -------------------------- - private static final TypeReference LIST_RESOURCES_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_RESOURCES_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference READ_RESOURCE_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef READ_RESOURCE_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF = new TypeRef<>() { }; /** * Retrieves the list of all resources provided by the server. Resources represent any * kind of UTF-8 encoded data that an MCP server makes available to clients, such as * database records, API responses, log files, and more. - * @return A Mono that completes with the list of resources result. + * @return A Mono that completes with the list of all resources result * @see McpSchema.ListResourcesResult * @see #readResource(McpSchema.Resource) */ public Mono listResources() { - return this.listResources(null); + return this.listResources(McpSchema.FIRST_PAGE) + .expand(result -> (result.nextCursor() != null) ? this.listResources(result.nextCursor()) : Mono.empty()) + .reduce(new McpSchema.ListResourcesResult(new ArrayList<>(), null), (allResourcesResult, result) -> { + allResourcesResult.resources().addAll(result.resources()); + return allResourcesResult; + }) + .map(result -> new McpSchema.ListResourcesResult(Collections.unmodifiableList(result.resources()), null)); } /** @@ -582,12 +720,13 @@ public Mono listResources() { * @see #readResource(McpSchema.Resource) */ public Mono listResources(String cursor) { - return this.withInitializationCheck("listing resources", initializedResult -> { - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server does not provide the resources capability")); + return this.initializer.withInitialization("listing resources", init -> { + if (init.initializeResult().capabilities().resources() == null) { + return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_RESOURCES_RESULT_TYPE_REF); + return init.mcpSession() + .sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_RESOURCES_RESULT_TYPE_REF); }); } @@ -613,12 +752,12 @@ public Mono readResource(McpSchema.Resource resour * @see McpSchema.ReadResourceResult */ public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.withInitializationCheck("reading resources", initializedResult -> { - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server does not provide the resources capability")); + return this.initializer.withInitialization("reading resources", init -> { + if (init.initializeResult().capabilities().resources() == null) { + return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, - READ_RESOURCE_RESULT_TYPE_REF); + return init.mcpSession() + .sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, READ_RESOURCE_RESULT_TYPE_REF); }); } @@ -626,11 +765,20 @@ public Mono readResource(McpSchema.ReadResourceReq * Retrieves the list of all resource templates provided by the server. Resource * templates allow servers to expose parameterized resources using URI templates, * enabling dynamic resource access based on variable parameters. - * @return A Mono that completes with the list of resource templates result. + * @return A Mono that completes with the list of all resource templates result * @see McpSchema.ListResourceTemplatesResult */ public Mono listResourceTemplates() { - return this.listResourceTemplates(null); + return this.listResourceTemplates(McpSchema.FIRST_PAGE) + .expand(result -> (result.nextCursor() != null) ? this.listResourceTemplates(result.nextCursor()) + : Mono.empty()) + .reduce(new McpSchema.ListResourceTemplatesResult(new ArrayList<>(), null), + (allResourceTemplatesResult, result) -> { + allResourceTemplatesResult.resourceTemplates().addAll(result.resourceTemplates()); + return allResourceTemplatesResult; + }) + .map(result -> new McpSchema.ListResourceTemplatesResult( + Collections.unmodifiableList(result.resourceTemplates()), null)); } /** @@ -642,12 +790,13 @@ public Mono listResourceTemplates() { * @see McpSchema.ListResourceTemplatesResult */ public Mono listResourceTemplates(String cursor) { - return this.withInitializationCheck("listing resource templates", initializedResult -> { - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server does not provide the resources capability")); + return this.initializer.withInitialization("listing resource templates", init -> { + if (init.initializeResult().capabilities().resources() == null) { + return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, - new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); + return init.mcpSession() + .sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); }); } @@ -661,7 +810,7 @@ public Mono listResourceTemplates(String * @see #unsubscribeResource(McpSchema.UnsubscribeRequest) */ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - return this.withInitializationCheck("subscribing to resources", initializedResult -> this.mcpSession + return this.initializer.withInitialization("subscribing to resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE)); } @@ -675,7 +824,7 @@ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) * @see #subscribeResource(McpSchema.SubscribeRequest) */ public Mono unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - return this.withInitializationCheck("unsubscribing from resources", initializedResult -> this.mcpSession + return this.initializer.withInitialization("unsubscribing from resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, VOID_TYPE_REFERENCE)); } @@ -690,23 +839,47 @@ private NotificationHandler asyncResourcesChangeNotificationHandler( .then()); } + private NotificationHandler asyncResourcesUpdatedNotificationHandler( + List, Mono>> resourcesUpdateConsumers) { + return params -> { + McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification = transport.unmarshalFrom(params, + new TypeRef<>() { + }); + + return readResource(new McpSchema.ReadResourceRequest(resourcesUpdatedNotification.uri())) + .flatMap(readResourceResult -> Flux.fromIterable(resourcesUpdateConsumers) + .flatMap(consumer -> consumer.apply(readResourceResult.contents())) + .onErrorResume(error -> { + logger.error("Error handling resource update notification", error); + return Mono.empty(); + }) + .then()); + }; + } + // -------------------------- // Prompts // -------------------------- - private static final TypeReference LIST_PROMPTS_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_PROMPTS_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference GET_PROMPT_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef GET_PROMPT_RESULT_TYPE_REF = new TypeRef<>() { }; /** * Retrieves the list of all prompts provided by the server. - * @return A Mono that completes with the list of prompts result. + * @return A Mono that completes with the list of all prompts result. * @see McpSchema.ListPromptsResult * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts() { - return this.listPrompts(null); + return this.listPrompts(McpSchema.FIRST_PAGE) + .expand(result -> (result.nextCursor() != null) ? this.listPrompts(result.nextCursor()) : Mono.empty()) + .reduce(new ListPromptsResult(new ArrayList<>(), null), (allPromptsResult, result) -> { + allPromptsResult.prompts().addAll(result.prompts()); + return allPromptsResult; + }) + .map(result -> new McpSchema.ListPromptsResult(Collections.unmodifiableList(result.prompts()), null)); } /** @@ -717,7 +890,7 @@ public Mono listPrompts() { * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts(String cursor) { - return this.withInitializationCheck("listing prompts", initializedResult -> this.mcpSession + return this.initializer.withInitialization("listing prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF)); } @@ -731,7 +904,7 @@ public Mono listPrompts(String cursor) { * @see #listPrompts() */ public Mono getPrompt(GetPromptRequest getPromptRequest) { - return this.withInitializationCheck("getting prompts", initializedResult -> this.mcpSession + return this.initializer.withInitialization("getting prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF)); } @@ -754,8 +927,7 @@ private NotificationHandler asyncLoggingNotificationHandler( return params -> { McpSchema.LoggingMessageNotification loggingMessageNotification = transport.unmarshalFrom(params, - new TypeReference() { - }); + LOGGING_MESSAGE_NOTIFICATION_TYPE_REF); return Flux.fromIterable(loggingConsumers) .flatMap(consumer -> consumer.apply(loggingMessageNotification)) @@ -771,23 +943,60 @@ private NotificationHandler asyncLoggingNotificationHandler( * @see McpSchema.LoggingLevel */ public Mono setLoggingLevel(LoggingLevel loggingLevel) { - Assert.notNull(loggingLevel, "Logging level must not be null"); + if (loggingLevel == null) { + return Mono.error(new IllegalArgumentException("Logging level must not be null")); + } - return this.withInitializationCheck("setting logging level", initializedResult -> { - String levelName = this.transport.unmarshalFrom(loggingLevel, new TypeReference() { - }); - Map params = Map.of("level", levelName); - return this.mcpSession.sendNotification(McpSchema.METHOD_LOGGING_SET_LEVEL, params); + return this.initializer.withInitialization("setting logging level", init -> { + if (init.initializeResult().capabilities().logging() == null) { + return Mono.error(new IllegalStateException("Server's Logging capabilities are not enabled!")); + } + var params = new McpSchema.SetLevelRequest(loggingLevel); + return init.mcpSession().sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, OBJECT_TYPE_REF).then(); }); } + private NotificationHandler asyncProgressNotificationHandler( + List>> progressConsumers) { + + return params -> { + McpSchema.ProgressNotification progressNotification = transport.unmarshalFrom(params, + PROGRESS_NOTIFICATION_TYPE_REF); + + return Flux.fromIterable(progressConsumers) + .flatMap(consumer -> consumer.apply(progressNotification)) + .then(); + }; + } + /** * This method is package-private and used for test only. Should not be called by user * code. * @param protocolVersions the Client supported protocol versions. */ void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; + this.initializer.setProtocolVersions(protocolVersions); + } + + // -------------------------- + // Completions + // -------------------------- + private static final TypeRef COMPLETION_COMPLETE_RESULT_TYPE_REF = new TypeRef<>() { + }; + + /** + * Sends a completion/complete request to generate value suggestions based on a given + * reference and argument. This is typically used to provide auto-completion options + * for user input fields. + * @param completeRequest The request containing the prompt or resource reference and + * argument for which to generate completions. + * @return A Mono that completes with the result containing completion suggestions. + * @see McpSchema.CompleteRequest + * @see McpSchema.CompleteResult + */ + public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) { + return this.initializer.withInitialization("complete completions", init -> init.mcpSession() + .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF)); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java similarity index 65% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java index 7ab01b70c..c9989f832 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -4,25 +4,30 @@ package io.modelcontextprotocol.client; -import java.time.Duration; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; - -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.Implementation; import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + /** * Factory class for creating Model Context Protocol (MCP) clients. MCP is a protocol that * enables AI models to interact with external tools and resources through a standardized @@ -70,6 +75,7 @@ * .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> System.out.println("Resources updated: " + resources))) * .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> System.out.println("Prompts updated: " + prompts))) * .loggingConsumer(message -> Mono.fromRunnable(() -> System.out.println("Log message: " + message))) + * .resourcesUpdateConsumer(resourceContents -> Mono.fromRunnable(() -> System.out.println("Resources contents updated: " + resourceContents))) * .build(); * } * @@ -95,6 +101,7 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Anurag Pant * @see McpAsyncClient * @see McpSyncClient * @see McpTransport @@ -114,7 +121,7 @@ public interface McpClient { * @return A new builder instance for configuring the client * @throws IllegalArgumentException if transport is null */ - static SyncSpec sync(ClientMcpTransport transport) { + static SyncSpec sync(McpClientTransport transport) { return new SyncSpec(transport); } @@ -131,7 +138,7 @@ static SyncSpec sync(ClientMcpTransport transport) { * @return A new builder instance for configuring the client * @throws IllegalArgumentException if transport is null */ - static AsyncSpec async(ClientMcpTransport transport) { + static AsyncSpec async(McpClientTransport transport) { return new AsyncSpec(transport); } @@ -153,13 +160,15 @@ static AsyncSpec async(ClientMcpTransport transport) { */ class SyncSpec { - private final ClientMcpTransport transport; + private final McpClientTransport transport; private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private Duration initializationTimeout = Duration.ofSeconds(20); + private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Java SDK MCP Client", "1.0.0"); + private Implementation clientInfo = new Implementation("Java SDK MCP Client", "0.15.0"); private final Map roots = new HashMap<>(); @@ -167,13 +176,25 @@ class SyncSpec { private final List>> resourcesChangeConsumers = new ArrayList<>(); + private final List>> resourcesUpdateConsumers = new ArrayList<>(); + private final List>> promptsChangeConsumers = new ArrayList<>(); private final List> loggingConsumers = new ArrayList<>(); + private final List> progressConsumers = new ArrayList<>(); + private Function samplingHandler; - private SyncSpec(ClientMcpTransport transport) { + private Function elicitationHandler; + + private Supplier contextProvider = () -> McpTransportContext.EMPTY; + + private JsonSchemaValidator jsonSchemaValidator; + + private boolean enableCallToolSchemaCaching = false; // Default to false + + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; } @@ -193,6 +214,18 @@ public SyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * @param initializationTimeout The duration to wait for the initialization + * lifecycle step to complete. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if initializationTimeout is null + */ + public SyncSpec initializationTimeout(Duration initializationTimeout) { + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + this.initializationTimeout = initializationTimeout; + return this; + } + /** * Sets the client capabilities that will be advertised to the server during * connection initialization. Capabilities define what features the client @@ -269,6 +302,21 @@ public SyncSpec sampling(Function sam return this; } + /** + * Sets a custom elicitation handler for processing elicitation message requests. + * The elicitation handler can modify or validate messages before they are sent to + * the server, enabling custom processing logic. + * @param elicitationHandler A function that processes elicitation requests and + * returns results. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if elicitationHandler is null + */ + public SyncSpec elicitation(Function elicitationHandler) { + Assert.notNull(elicitationHandler, "Elicitation handler must not be null"); + this.elicitationHandler = elicitationHandler; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -299,6 +347,22 @@ public SyncSpec resourcesChangeConsumer(Consumer> resou return this; } + /** + * Adds a consumer to be notified when a specific resource is updated. This allows + * the client to react to changes in individual resources, such as updates to + * their content or metadata. + * @param resourcesUpdateConsumer A consumer function that processes the updated + * resource and returns a Mono indicating the completion of the processing. Must + * not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException If the resourcesUpdateConsumer is null. + */ + public SyncSpec resourcesUpdateConsumer(Consumer> resourcesUpdateConsumer) { + Assert.notNull(resourcesUpdateConsumer, "Resources update consumer must not be null"); + this.resourcesUpdateConsumers.add(resourcesUpdateConsumer); + return this; + } + /** * Adds a consumer to be notified when the available prompts change. This allows * the client to react to changes in the server's prompt templates, such as new @@ -342,6 +406,78 @@ public SyncSpec loggingConsumers(List progressConsumer) { + Assert.notNull(progressConsumer, "Progress consumer must not be null"); + this.progressConsumers.add(progressConsumer); + return this; + } + + /** + * Adds a multiple consumers to be notified of progress notifications from the + * server. This allows the client to track long-running operations and provide + * feedback to users. + * @param progressConsumers A list of consumers that receives progress + * notifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if progressConsumer is null + */ + public SyncSpec progressConsumers(List> progressConsumers) { + Assert.notNull(progressConsumers, "Progress consumers must not be null"); + this.progressConsumers.addAll(progressConsumers); + return this; + } + + /** + * Add a provider of {@link McpTransportContext}, providing a context before + * calling any client operation. This allows to extract thread-locals and hand + * them over to the underlying transport. + *

    + * There is no direct equivalent in {@link AsyncSpec}. To achieve the same result, + * append {@code contextWrite(McpTransportContext.KEY, context)} to any + * {@link McpAsyncClient} call. + * @param contextProvider A supplier to create a context + * @return This builder for method chaining + */ + public SyncSpec transportContextProvider(Supplier contextProvider) { + this.contextProvider = contextProvider; + return this; + } + + /** + * Add a {@link JsonSchemaValidator} to validate the JSON structure of the + * structured output. + * @param jsonSchemaValidator A validator to validate the JSON structure of the + * structured output. Must not be null. + * @return This builder for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public SyncSpec jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enables automatic schema caching during callTool operations. When a tool's + * output schema is not found in the cache, callTool will automatically fetch and + * cache all tool schemas via listTools. + * @param enableCallToolSchemaCaching true to enable, false to disable + * @return This builder instance for method chaining + */ + public SyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching) { + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + return this; + } + /** * Create an instance of {@link McpSyncClient} with the provided configurations or * sensible defaults. @@ -349,12 +485,15 @@ public SyncSpec loggingConsumers(List roots = new HashMap<>(); @@ -391,13 +532,23 @@ class AsyncSpec { private final List, Mono>> resourcesChangeConsumers = new ArrayList<>(); + private final List, Mono>> resourcesUpdateConsumers = new ArrayList<>(); + private final List, Mono>> promptsChangeConsumers = new ArrayList<>(); private final List>> loggingConsumers = new ArrayList<>(); + private final List>> progressConsumers = new ArrayList<>(); + private Function> samplingHandler; - private AsyncSpec(ClientMcpTransport transport) { + private Function> elicitationHandler; + + private JsonSchemaValidator jsonSchemaValidator; + + private boolean enableCallToolSchemaCaching = false; // Default to false + + private AsyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; } @@ -417,6 +568,18 @@ public AsyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * @param initializationTimeout The duration to wait for the initialization + * lifecycle step to complete. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if initializationTimeout is null + */ + public AsyncSpec initializationTimeout(Duration initializationTimeout) { + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + this.initializationTimeout = initializationTimeout; + return this; + } + /** * Sets the client capabilities that will be advertised to the server during * connection initialization. Capabilities define what features the client @@ -493,6 +656,21 @@ public AsyncSpec sampling(Function> elicitationHandler) { + Assert.notNull(elicitationHandler, "Elicitation handler must not be null"); + this.elicitationHandler = elicitationHandler; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -524,6 +702,23 @@ public AsyncSpec resourcesChangeConsumer( return this; } + /** + * Adds a consumer to be notified when a specific resource is updated. This allows + * the client to react to changes in individual resources, such as updates to + * their content or metadata. + * @param resourcesUpdateConsumer A consumer function that processes the updated + * resource and returns a Mono indicating the completion of the processing. Must + * not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException If the resourcesUpdateConsumer is null. + */ + public AsyncSpec resourcesUpdateConsumer( + Function, Mono> resourcesUpdateConsumer) { + Assert.notNull(resourcesUpdateConsumer, "Resources update consumer must not be null"); + this.resourcesUpdateConsumers.add(resourcesUpdateConsumer); + return this; + } + /** * Adds a consumer to be notified when the available prompts change. This allows * the client to react to changes in the server's prompt templates, such as new @@ -568,16 +763,76 @@ public AsyncSpec loggingConsumers( return this; } + /** + * Adds a consumer to be notified of progress notifications from the server. This + * allows the client to track long-running operations and provide feedback to + * users. + * @param progressConsumer A consumer that receives progress notifications. Must + * not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if progressConsumer is null + */ + public AsyncSpec progressConsumer(Function> progressConsumer) { + Assert.notNull(progressConsumer, "Progress consumer must not be null"); + this.progressConsumers.add(progressConsumer); + return this; + } + + /** + * Adds a multiple consumers to be notified of progress notifications from the + * server. This allows the client to track long-running operations and provide + * feedback to users. + * @param progressConsumers A list of consumers that receives progress + * notifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if progressConsumer is null + */ + public AsyncSpec progressConsumers( + List>> progressConsumers) { + Assert.notNull(progressConsumers, "Progress consumers must not be null"); + this.progressConsumers.addAll(progressConsumers); + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool responses against + * output schemas. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public AsyncSpec jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enables automatic schema caching during callTool operations. When a tool's + * output schema is not found in the cache, callTool will automatically fetch and + * cache all tool schemas via listTools. + * @param enableCallToolSchemaCaching true to enable, false to disable + * @return This builder instance for method chaining + */ + public AsyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching) { + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + return this; + } + /** * Create an instance of {@link McpAsyncClient} with the provided configurations * or sensible defaults. * @return a new instance of {@link McpAsyncClient}. */ public McpAsyncClient build() { - return new McpAsyncClient(this.transport, this.requestTimeout, + var jsonSchemaValidator = (this.jsonSchemaValidator != null) ? this.jsonSchemaValidator + : JsonSchemaValidator.getDefault(); + return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, + jsonSchemaValidator, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, - this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers, - this.loggingConsumers, this.samplingHandler)); + this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, + this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, + this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching)); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java similarity index 61% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index 284b93f88..127d53337 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -59,14 +59,21 @@ class McpClientFeatures { * @param resourcesChangeConsumers the resources change consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List, Mono>> toolsChangeConsumers, List, Mono>> resourcesChangeConsumers, + List, Mono>> resourcesUpdateConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, - Function> samplingHandler) { + List>> progressConsumers, + Function> samplingHandler, + Function> elicitationHandler, + boolean enableCallToolSchemaCaching) { /** * Create an instance and validate the arguments. @@ -76,29 +83,58 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c * @param resourcesChangeConsumers the resources change consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List, Mono>> toolsChangeConsumers, List, Mono>> resourcesChangeConsumers, + List, Mono>> resourcesUpdateConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, - Function> samplingHandler) { + List>> progressConsumers, + Function> samplingHandler, + Function> elicitationHandler, + boolean enableCallToolSchemaCaching) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, - samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null); + samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); this.roots = roots != null ? new ConcurrentHashMap<>(roots) : new ConcurrentHashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); this.resourcesChangeConsumers = resourcesChangeConsumers != null ? resourcesChangeConsumers : List.of(); + this.resourcesUpdateConsumers = resourcesUpdateConsumers != null ? resourcesUpdateConsumers : List.of(); this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); + this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; + this.elicitationHandler = elicitationHandler; + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. + */ + public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, + Map roots, + List, Mono>> toolsChangeConsumers, + List, Mono>> resourcesChangeConsumers, + List, Mono>> resourcesUpdateConsumers, + List, Mono>> promptsChangeConsumers, + List>> loggingConsumers, + Function> samplingHandler, + Function> elicitationHandler) { + this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, + resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, + elicitationHandler, false); } /** @@ -122,8 +158,13 @@ public static Async fromSync(Sync syncSpec) { .subscribeOn(Schedulers.boundedElastic())); } - List, Mono>> promptsChangeConsumers = new ArrayList<>(); + List, Mono>> resourcesUpdateConsumers = new ArrayList<>(); + for (Consumer> consumer : syncSpec.resourcesUpdateConsumers()) { + resourcesUpdateConsumers.add(r -> Mono.fromRunnable(() -> consumer.accept(r)) + .subscribeOn(Schedulers.boundedElastic())); + } + List, Mono>> promptsChangeConsumers = new ArrayList<>(); for (Consumer> consumer : syncSpec.promptsChangeConsumers()) { promptsChangeConsumers.add(p -> Mono.fromRunnable(() -> consumer.accept(p)) .subscribeOn(Schedulers.boundedElastic())); @@ -135,12 +176,24 @@ public static Async fromSync(Sync syncSpec) { .subscribeOn(Schedulers.boundedElastic())); } + List>> progressConsumers = new ArrayList<>(); + for (Consumer consumer : syncSpec.progressConsumers()) { + progressConsumers.add(l -> Mono.fromRunnable(() -> consumer.accept(l)) + .subscribeOn(Schedulers.boundedElastic())); + } + Function> samplingHandler = r -> Mono .fromCallable(() -> syncSpec.samplingHandler().apply(r)) .subscribeOn(Schedulers.boundedElastic()); + + Function> elicitationHandler = r -> Mono + .fromCallable(() -> syncSpec.elicitationHandler().apply(r)) + .subscribeOn(Schedulers.boundedElastic()); + return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), - toolsChangeConsumers, resourcesChangeConsumers, promptsChangeConsumers, loggingConsumers, - samplingHandler); + toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, + loggingConsumers, progressConsumers, samplingHandler, elicitationHandler, + syncSpec.enableCallToolSchemaCaching); } } @@ -155,14 +208,21 @@ public static Async fromSync(Sync syncSpec) { * @param resourcesChangeConsumers the resources change consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, List>> resourcesChangeConsumers, + List>> resourcesUpdateConsumers, List>> promptsChangeConsumers, List> loggingConsumers, - Function samplingHandler) { + List> progressConsumers, + Function samplingHandler, + Function elicitationHandler, + boolean enableCallToolSchemaCaching) { /** * Create an instance and validate the arguments. @@ -171,30 +231,59 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param roots the roots. * @param toolsChangeConsumers the tools change consumers. * @param resourcesChangeConsumers the resources change consumers. + * @param resourcesUpdateConsumers the resource update consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, List>> resourcesChangeConsumers, + List>> resourcesUpdateConsumers, List>> promptsChangeConsumers, List> loggingConsumers, - Function samplingHandler) { + List> progressConsumers, + Function samplingHandler, + Function elicitationHandler, + boolean enableCallToolSchemaCaching) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, - samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null); + samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); this.roots = roots != null ? new HashMap<>(roots) : new HashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); this.resourcesChangeConsumers = resourcesChangeConsumers != null ? resourcesChangeConsumers : List.of(); + this.resourcesUpdateConsumers = resourcesUpdateConsumers != null ? resourcesUpdateConsumers : List.of(); this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); + this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; + this.elicitationHandler = elicitationHandler; + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. + */ + public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, + Map roots, List>> toolsChangeConsumers, + List>> resourcesChangeConsumers, + List>> resourcesUpdateConsumers, + List>> promptsChangeConsumers, + List> loggingConsumers, + Function samplingHandler, + Function elicitationHandler) { + this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, + resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, + elicitationHandler, false); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java similarity index 67% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index e5d964b7a..7fdaa8941 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -5,16 +5,19 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.function.Supplier; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; /** * A synchronous client implementation for the Model Context Protocol (MCP) that wraps an @@ -47,6 +50,7 @@ * * @author Dariusz Jędrzejczyk * @author Christian Tzolov + * @author Jihoon Kim * @see McpClient * @see McpAsyncClient * @see McpSchema @@ -62,17 +66,28 @@ public class McpSyncClient implements AutoCloseable { private final McpAsyncClient delegate; + private final Supplier contextProvider; + /** * Create a new McpSyncClient with the given delegate. * @param delegate the asynchronous kernel on top of which this synchronous client * provides a blocking API. - * @deprecated Use {@link McpClient#sync(ClientMcpTransport)} to obtain an instance. + * @param contextProvider the supplier of context before calling any non-blocking + * operation on underlying delegate */ - @Deprecated - // TODO make the constructor package private post-deprecation - public McpSyncClient(McpAsyncClient delegate) { + McpSyncClient(McpAsyncClient delegate, Supplier contextProvider) { Assert.notNull(delegate, "The delegate can not be null"); + Assert.notNull(contextProvider, "The contextProvider can not be null"); this.delegate = delegate; + this.contextProvider = contextProvider; + } + + /** + * Get the current initialization result. + * @return the initialization result. + */ + public McpSchema.InitializeResult getCurrentInitializationResult() { + return this.delegate.getCurrentInitializationResult(); } /** @@ -83,6 +98,15 @@ public McpSchema.ServerCapabilities getServerCapabilities() { return this.delegate.getServerCapabilities(); } + /** + * Get the server instructions that provide guidance to the client on how to interact + * with this server. + * @return The instructions + */ + public String getServerInstructions() { + return this.delegate.getServerInstructions(); + } + /** * Get the server implementation information. * @return The server implementation details @@ -91,6 +115,14 @@ public McpSchema.Implementation getServerInfo() { return this.delegate.getServerInfo(); } + /** + * Check if the client-server connection is initialized. + * @return true if the client-server connection is initialized + */ + public boolean isInitialized() { + return this.delegate.isInitialized(); + } + /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities @@ -154,14 +186,14 @@ public boolean closeGracefully() { public McpSchema.InitializeResult initialize() { // TODO: block takes no argument here as we assume the async client is // configured with a requestTimeout at all times - return this.delegate.initialize().block(); + return withProvidedContext(this.delegate.initialize()).block(); } /** * Send a roots/list_changed notification. */ public void rootsListChangedNotification() { - this.delegate.rootsListChangedNotification().block(); + withProvidedContext(this.delegate.rootsListChangedNotification()).block(); } /** @@ -183,7 +215,7 @@ public void removeRoot(String rootUri) { * @return */ public Object ping() { - return this.delegate.ping().block(); + return withProvidedContext(this.delegate.ping()).block(); } // -------------------------- @@ -201,17 +233,18 @@ public Object ping() { * Boolean indicating if the execution failed (true) or succeeded (false/absent) */ public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolRequest) { - return this.delegate.callTool(callToolRequest).block(); + return withProvidedContext(this.delegate.callTool(callToolRequest)).block(); + } /** * Retrieves the list of all tools provided by the server. - * @return The list of tools result containing: - tools: List of available tools, each - * with a name, description, and input schema - nextCursor: Optional cursor for + * @return The list of all tools result containing: - tools: List of available tools, + * each with a name, description, and input schema - nextCursor: Optional cursor for * pagination if more tools are available */ public McpSchema.ListToolsResult listTools() { - return this.delegate.listTools().block(); + return withProvidedContext(this.delegate.listTools()).block(); } /** @@ -222,7 +255,8 @@ public McpSchema.ListToolsResult listTools() { * pagination if more tools are available */ public McpSchema.ListToolsResult listTools(String cursor) { - return this.delegate.listTools(cursor).block(); + return withProvidedContext(this.delegate.listTools(cursor)).block(); + } // -------------------------- @@ -230,20 +264,22 @@ public McpSchema.ListToolsResult listTools(String cursor) { // -------------------------- /** - * Send a resources/list request. - * @param cursor the cursor - * @return the list of resources result. + * Retrieves the list of all resources provided by the server. + * @return The list of all resources result */ - public McpSchema.ListResourcesResult listResources(String cursor) { - return this.delegate.listResources(cursor).block(); + public McpSchema.ListResourcesResult listResources() { + return withProvidedContext(this.delegate.listResources()).block(); + } /** - * Send a resources/list request. - * @return the list of resources result. + * Retrieves a paginated list of resources provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @return The list of resources result */ - public McpSchema.ListResourcesResult listResources() { - return this.delegate.listResources().block(); + public McpSchema.ListResourcesResult listResources(String cursor) { + return withProvidedContext(this.delegate.listResources(cursor)).block(); + } /** @@ -252,7 +288,8 @@ public McpSchema.ListResourcesResult listResources() { * @return the resource content. */ public McpSchema.ReadResourceResult readResource(McpSchema.Resource resource) { - return this.delegate.readResource(resource).block(); + return withProvidedContext(this.delegate.readResource(resource)).block(); + } /** @@ -261,27 +298,30 @@ public McpSchema.ReadResourceResult readResource(McpSchema.Resource resource) { * @return the resource content. */ public McpSchema.ReadResourceResult readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.delegate.readResource(readResourceRequest).block(); + return withProvidedContext(this.delegate.readResource(readResourceRequest)).block(); + + } + + /** + * Retrieves the list of all resource templates provided by the server. + * @return The list of all resource templates result. + */ + public McpSchema.ListResourceTemplatesResult listResourceTemplates() { + return withProvidedContext(this.delegate.listResourceTemplates()).block(); + } /** * Resource templates allow servers to expose parameterized resources using URI * templates. Arguments may be auto-completed through the completion API. * - * Request a list of resource templates the server has. - * @param cursor the cursor - * @return the list of resource templates result. + * Retrieves a paginated list of resource templates provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @return The list of resource templates result. */ public McpSchema.ListResourceTemplatesResult listResourceTemplates(String cursor) { - return this.delegate.listResourceTemplates(cursor).block(); - } + return withProvidedContext(this.delegate.listResourceTemplates(cursor)).block(); - /** - * Request a list of resource templates the server has. - * @return the list of resource templates result. - */ - public McpSchema.ListResourceTemplatesResult listResourceTemplates() { - return this.delegate.listResourceTemplates().block(); } /** @@ -294,7 +334,8 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates() { * subscribe to. */ public void subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - this.delegate.subscribeResource(subscribeRequest).block(); + withProvidedContext(this.delegate.subscribeResource(subscribeRequest)).block(); + } /** @@ -303,22 +344,34 @@ public void subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { * to unsubscribe from. */ public void unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - this.delegate.unsubscribeResource(unsubscribeRequest).block(); + withProvidedContext(this.delegate.unsubscribeResource(unsubscribeRequest)).block(); + } // -------------------------- // Prompts // -------------------------- - public ListPromptsResult listPrompts(String cursor) { - return this.delegate.listPrompts(cursor).block(); - } + /** + * Retrieves the list of all prompts provided by the server. + * @return The list of all prompts result. + */ public ListPromptsResult listPrompts() { - return this.delegate.listPrompts().block(); + return withProvidedContext(this.delegate.listPrompts()).block(); + } + + /** + * Retrieves a paginated list of prompts provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @return The list of prompts result. + */ + public ListPromptsResult listPrompts(String cursor) { + return withProvidedContext(this.delegate.listPrompts(cursor)).block(); + } public GetPromptResult getPrompt(GetPromptRequest getPromptRequest) { - return this.delegate.getPrompt(getPromptRequest).block(); + return withProvidedContext(this.delegate.getPrompt(getPromptRequest)).block(); } /** @@ -326,7 +379,29 @@ public GetPromptResult getPrompt(GetPromptRequest getPromptRequest) { * @param loggingLevel the min logging level */ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) { - this.delegate.setLoggingLevel(loggingLevel).block(); + withProvidedContext(this.delegate.setLoggingLevel(loggingLevel)).block(); + + } + + /** + * Send a completion/complete request. + * @param completeRequest the completion request contains the prompt or resource + * reference and arguments for generating suggestions. + * @return the completion result containing suggested values. + */ + public McpSchema.CompleteResult completeCompletion(McpSchema.CompleteRequest completeRequest) { + return withProvidedContext(this.delegate.completeCompletion(completeRequest)).block(); + + } + + /** + * For a given action, on assembly, capture the "context" via the + * {@link #contextProvider} and store it in the Reactor context. + * @param action the action to perform + * @return the result of the action + */ + private Mono withProvidedContext(Mono action) { + return action.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, this.contextProvider.get())); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java new file mode 100644 index 000000000..ae093316f --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -0,0 +1,513 @@ +/* + * Copyright 2024 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +/** + * Server-Sent Events (SSE) implementation of the + * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE + * transport specification, using Java's HttpClient. + * + *

    + * This transport implementation establishes a bidirectional communication channel between + * client and server using SSE for server-to-client messages and HTTP POST requests for + * client-to-server messages. The transport: + *

      + *
    • Establishes an SSE connection to receive server messages
    • + *
    • Handles endpoint discovery through SSE events
    • + *
    • Manages message serialization/deserialization using Jackson
    • + *
    • Provides graceful connection termination
    • + *
    + * + *

    + * The transport supports two types of SSE events: + *

      + *
    • 'endpoint' - Contains the URL for sending client messages
    • + *
    • 'message' - Contains JSON-RPC message payload
    • + *
    + * + * @author Christian Tzolov + * @see io.modelcontextprotocol.spec.McpTransport + * @see io.modelcontextprotocol.spec.McpClientTransport + */ +public class HttpClientSseClientTransport implements McpClientTransport { + + private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2024_11_05; + + private static final String MCP_PROTOCOL_VERSION_HEADER_NAME = "MCP-Protocol-Version"; + + private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class); + + /** SSE event type for JSON-RPC messages */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + /** SSE event type for endpoint discovery */ + private static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** Default SSE endpoint path */ + private static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + /** Base URI for the MCP server */ + private final URI baseUri; + + /** SSE endpoint path */ + private final String sseEndpoint; + + /** + * HTTP client for sending messages to the server. Uses HTTP POST over the message + * endpoint + */ + private final HttpClient httpClient; + + /** HTTP request builder for building requests to send messages to the server */ + private final HttpRequest.Builder requestBuilder; + + /** JSON mapper for message serialization/deserialization */ + protected McpJsonMapper jsonMapper; + + /** Flag indicating if the transport is in closing state */ + private volatile boolean isClosing = false; + + /** Holds the SSE subscription disposable */ + private final AtomicReference sseSubscription = new AtomicReference<>(); + + /** + * Sink for managing the message endpoint URI provided by the server. Stores the most + * recent endpoint URI and makes it available for outbound message processing. + */ + protected final Sinks.One messageEndpointSink = Sinks.one(); + + /** + * Customizer to modify requests before they are executed. + */ + private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; + + /** + * Creates a new transport instance with custom HTTP client builder, object mapper, + * and headers. + * @param httpClient the HTTP client to use + * @param requestBuilder the HTTP request builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param jsonMapper the object mapper for JSON serialization/deserialization + * @param httpRequestCustomizer customizer for the requestBuilder before executing + * requests + * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + */ + HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, + String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + Assert.hasText(baseUri, "baseUri must not be empty"); + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); + Assert.notNull(httpClient, "httpClient must not be null"); + Assert.notNull(requestBuilder, "requestBuilder must not be null"); + Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null"); + this.baseUri = URI.create(baseUri); + this.sseEndpoint = sseEndpoint; + this.jsonMapper = jsonMapper; + this.httpClient = httpClient; + this.requestBuilder = requestBuilder; + this.httpRequestCustomizer = httpRequestCustomizer; + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + + /** + * Creates a new builder for {@link HttpClientSseClientTransport}. + * @param baseUri the base URI of the MCP server + * @return a new builder instance + */ + public static Builder builder(String baseUri) { + return new Builder().baseUri(baseUri); + } + + /** + * Builder for {@link HttpClientSseClientTransport}. + */ + public static class Builder { + + private String baseUri; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private HttpClient.Builder clientBuilder = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1); + + private McpJsonMapper jsonMapper; + + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + + private McpAsyncHttpClientRequestCustomizer httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.NOOP; + + private Duration connectTimeout = Duration.ofSeconds(10); + + /** + * Creates a new builder instance. + */ + Builder() { + // Default constructor + } + + /** + * Creates a new builder with the specified base URI. + * @param baseUri the base URI of the MCP server + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. + * This constructor is deprecated and will be removed or made {@code protected} or + * {@code private} in a future release. + */ + @Deprecated(forRemoval = true) + public Builder(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + } + + /** + * Sets the base URI. + * @param baseUri the base URI + * @return this builder + */ + Builder baseUri(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + return this; + } + + /** + * Sets the SSE endpoint path. + * @param sseEndpoint the SSE endpoint path + * @return this builder + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the HTTP client builder. + * @param clientBuilder the HTTP client builder + * @return this builder + */ + public Builder clientBuilder(HttpClient.Builder clientBuilder) { + Assert.notNull(clientBuilder, "clientBuilder must not be null"); + this.clientBuilder = clientBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param clientCustomizer the consumer to customize the HTTP client builder + * @return this builder + */ + public Builder customizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + return this; + } + + /** + * Sets the HTTP request builder. + * @param requestBuilder the HTTP request builder + * @return this builder + */ + public Builder requestBuilder(HttpRequest.Builder requestBuilder) { + Assert.notNull(requestBuilder, "requestBuilder must not be null"); + this.requestBuilder = requestBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param requestCustomizer the consumer to customize the HTTP request builder + * @return this builder + */ + public Builder customizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + return this; + } + + /** + * Sets the JSON mapper implementation to use for serialization/deserialization. + * @param jsonMapper the JSON mapper + * @return this builder + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

    + * This overrides the customizer from + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)}. + *

    + * Do NOT use a blocking {@link McpSyncHttpClientRequestCustomizer} in a + * non-blocking context. Use + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)} + * instead. + * @param syncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder httpRequestCustomizer(McpSyncHttpClientRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.fromSync(syncHttpRequestCustomizer); + return this; + } + + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

    + * This overrides the customizer from + * {@link #httpRequestCustomizer(McpSyncHttpClientRequestCustomizer)}. + *

    + * Do NOT use a blocking implementation in a non-blocking context. + * @param asyncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer asyncHttpRequestCustomizer) { + this.httpRequestCustomizer = asyncHttpRequestCustomizer; + return this; + } + + /** + * Sets the connection timeout for the HTTP client. + * @param connectTimeout the connection timeout duration + * @return this builder + */ + public Builder connectTimeout(Duration connectTimeout) { + Assert.notNull(connectTimeout, "connectTimeout must not be null"); + this.connectTimeout = connectTimeout; + return this; + } + + /** + * Builds a new {@link HttpClientSseClientTransport} instance. + * @return a new transport instance + */ + public HttpClientSseClientTransport build() { + HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); + return new HttpClientSseClientTransport(httpClient, requestBuilder, baseUri, sseEndpoint, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, httpRequestCustomizer); + } + + } + + @Override + public Mono connect(Function, Mono> handler) { + var uri = Utils.resolveUri(this.baseUri, this.sseEndpoint); + + return Mono.deferContextual(ctx -> { + var builder = requestBuilder.copy() + .uri(uri) + .header("Accept", "text/event-stream") + .header("Cache-Control", "no-cache") + .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION) + .GET(); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null, transportContext)); + }).flatMap(requestBuilder -> Mono.create(sink -> { + Disposable connection = Flux.create(sseSink -> this.httpClient + .sendAsync(requestBuilder.build(), + responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) + .exceptionallyCompose(e -> { + sseSink.error(e); + return CompletableFuture.failedFuture(e); + })) + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) + .flatMap(responseEvent -> { + if (isClosing) { + return Mono.empty(); + } + + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + try { + if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + String messageEndpointUri = responseEvent.sseEvent().data(); + if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { + sink.success(); + return Flux.empty(); // No further processing needed + } + else { + sink.error(new RuntimeException("Failed to handle SSE endpoint event")); + } + } + else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, + responseEvent.sseEvent().data()); + sink.success(); + return Flux.just(message); + } + else { + logger.debug("Received unrecognized SSE event type: {}", responseEvent.sseEvent()); + sink.success(); + } + } + catch (IOException e) { + sink.error(new McpTransportException("Error processing SSE event", e)); + } + } + return Flux.error( + new RuntimeException("Failed to send message: " + responseEvent)); + + }) + .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))) + .onErrorComplete(t -> { + if (!isClosing) { + logger.warn("SSE stream observed an error", t); + sink.error(t); + } + return true; + }) + .doFinally(s -> { + Disposable ref = this.sseSubscription.getAndSet(null); + if (ref != null && !ref.isDisposed()) { + ref.dispose(); + } + }) + .contextWrite(sink.contextView()) + .subscribe(); + + this.sseSubscription.set(connection); + })); + } + + /** + * Sends a JSON-RPC message to the server. + * + *

    + * This method waits for the message endpoint to be discovered before sending the + * message. The message is serialized to JSON and sent as an HTTP POST request. + * @param message the JSON-RPC message to send + * @return a Mono that completes when the message is sent + * @throws McpError if the message endpoint is not available or the wait times out + */ + @Override + public Mono sendMessage(JSONRPCMessage message) { + + return this.messageEndpointSink.asMono().flatMap(messageEndpointUri -> { + if (isClosing) { + return Mono.empty(); + } + + return this.serializeMessage(message) + .flatMap(body -> sendHttpPost(messageEndpointUri, body).handle((response, sink) -> { + if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 + && response.statusCode() != 206) { + sink.error(new RuntimeException("Sending message failed with a non-OK HTTP code: " + + response.statusCode() + " - " + response.body())); + } + else { + sink.next(response); + sink.complete(); + } + })) + .doOnError(error -> { + if (!isClosing) { + logger.error("Error sending message: {}", error.getMessage()); + } + }); + }).then(); + + } + + private Mono serializeMessage(final JSONRPCMessage message) { + return Mono.defer(() -> { + try { + return Mono.just(jsonMapper.writeValueAsString(message)); + } + catch (IOException e) { + return Mono.error(new McpTransportException("Failed to serialize message", e)); + } + }); + } + + private Mono> sendHttpPost(final String endpoint, final String body) { + final URI requestUri = Utils.resolveUri(baseUri, endpoint); + return Mono.deferContextual(ctx -> { + var builder = this.requestBuilder.copy() + .uri(requestUri) + .header(HttpHeaders.CONTENT_TYPE, "application/json") + .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION) + .POST(HttpRequest.BodyPublishers.ofString(body)); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", requestUri, body, transportContext)); + }).flatMap(customizedBuilder -> { + var request = customizedBuilder.build(); + return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); + }); + } + + /** + * Gracefully closes the transport connection. + * + *

    + * Sets the closing flag and disposes of the SSE subscription. This prevents new + * messages from being sent and allows ongoing operations to complete. + * @return a Mono that completes when the closing process is initiated + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing = true; + Disposable subscription = sseSubscription.get(); + if (subscription != null && !subscription.isDisposed()) { + subscription.dispose(); + } + }); + } + + /** + * Unmarshal data to the specified type using the configured object mapper. + * @param data the data to unmarshal + * @param typeRef the type reference for the target type + * @param the target type + * @return the unmarshalled object + */ + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java new file mode 100644 index 000000000..0a8dff363 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -0,0 +1,832 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.time.Duration; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.ClosedMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportStream; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.McpTransportSession; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +/** + * An implementation of the Streamable HTTP protocol as defined by the + * 2025-03-26 version of the MCP specification. + * + *

    + * The transport is capable of resumability and reconnects. It reacts to transport-level + * session invalidation and will propagate {@link McpTransportSessionNotFoundException + * appropriate exceptions} to the higher level abstraction layer when needed in order to + * allow proper state management. The implementation handles servers that are stateful and + * provide session meta information, but can also communicate with stateless servers that + * do not provide a session identifier and do not support SSE streams. + *

    + *

    + * This implementation does not handle backwards compatibility with the "HTTP + * with SSE" transport. In order to communicate over the phased-out + * 2024-11-05 protocol, use {@link HttpClientSseClientTransport} or + * {@code WebFluxSseClientTransport}. + *

    + * + * @author Christian Tzolov + * @see Streamable + * HTTP transport specification + */ +public class HttpClientStreamableHttpTransport implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(HttpClientStreamableHttpTransport.class); + + private static final String DEFAULT_ENDPOINT = "/mcp"; + + /** + * HTTP client for sending messages to the server. Uses HTTP POST over the message + * endpoint + */ + private final HttpClient httpClient; + + /** HTTP request builder for building requests to send messages to the server */ + private final HttpRequest.Builder requestBuilder; + + /** + * Event type for JSON-RPC messages received through the SSE connection. The server + * sends messages with this event type to transmit JSON-RPC protocol data. + */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + private static final String APPLICATION_JSON = "application/json"; + + private static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static int NOT_FOUND = 404; + + public static int METHOD_NOT_ALLOWED = 405; + + public static int BAD_REQUEST = 400; + + private final McpJsonMapper jsonMapper; + + private final URI baseUri; + + private final String endpoint; + + private final boolean openConnectionOnStartup; + + private final boolean resumableStreams; + + private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; + + private final AtomicReference> activeSession = new AtomicReference<>(); + + private final AtomicReference, Mono>> handler = new AtomicReference<>(); + + private final AtomicReference> exceptionHandler = new AtomicReference<>(); + + private final List supportedProtocolVersions; + + private final String latestSupportedProtocolVersion; + + private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient httpClient, + HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, + boolean openConnectionOnStartup, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer, + List supportedProtocolVersions) { + this.jsonMapper = jsonMapper; + this.httpClient = httpClient; + this.requestBuilder = requestBuilder; + this.baseUri = URI.create(baseUri); + this.endpoint = endpoint; + this.resumableStreams = resumableStreams; + this.openConnectionOnStartup = openConnectionOnStartup; + this.activeSession.set(createTransportSession()); + this.httpRequestCustomizer = httpRequestCustomizer; + this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); + this.latestSupportedProtocolVersion = this.supportedProtocolVersions.stream() + .sorted(Comparator.reverseOrder()) + .findFirst() + .get(); + } + + @Override + public List protocolVersions() { + return supportedProtocolVersions; + } + + public static Builder builder(String baseUri) { + return new Builder(baseUri); + } + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler.set(handler); + if (this.openConnectionOnStartup) { + logger.debug("Eagerly opening connection on startup"); + return this.reconnect(null).onErrorComplete(t -> { + logger.warn("Eager connect failed ", t); + return true; + }).then(); + } + return Mono.empty(); + }); + } + + private McpTransportSession createTransportSession() { + Function> onClose = sessionId -> sessionId == null ? Mono.empty() + : createDelete(sessionId); + return new DefaultMcpTransportSession(onClose); + } + + private McpTransportSession createClosedSession(McpTransportSession existingSession) { + var existingSessionId = Optional.ofNullable(existingSession) + .filter(session -> !(session instanceof ClosedMcpTransportSession)) + .flatMap(McpTransportSession::sessionId) + .orElse(null); + return new ClosedMcpTransportSession<>(existingSessionId); + } + + private Publisher createDelete(String sessionId) { + + var uri = Utils.resolveUri(this.baseUri, this.endpoint); + return Mono.deferContextual(ctx -> { + var builder = this.requestBuilder.copy() + .uri(uri) + .header("Cache-Control", "no-cache") + .header(HttpHeaders.MCP_SESSION_ID, sessionId) + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) + .DELETE(); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "DELETE", uri, null, transportContext)); + }).flatMap(requestBuilder -> { + var request = requestBuilder.build(); + return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); + }).then(); + } + + @Override + public void setExceptionHandler(Consumer handler) { + logger.debug("Exception handler registered"); + this.exceptionHandler.set(handler); + } + + private void handleException(Throwable t) { + logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); + if (t instanceof McpTransportSessionNotFoundException) { + McpTransportSession invalidSession = this.activeSession.getAndSet(createTransportSession()); + logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); + invalidSession.close(); + } + Consumer handler = this.exceptionHandler.get(); + if (handler != null) { + handler.accept(t); + } + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + logger.debug("Graceful close triggered"); + McpTransportSession currentSession = this.activeSession.getAndUpdate(this::createClosedSession); + if (currentSession != null) { + return Mono.from(currentSession.closeGracefully()); + } + return Mono.empty(); + }); + } + + private Mono reconnect(McpTransportStream stream) { + + return Mono.deferContextual(ctx -> { + + if (stream != null) { + logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); + } + else { + logger.debug("Reconnecting with no prior stream"); + } + + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + var uri = Utils.resolveUri(this.baseUri, this.endpoint); + + Disposable connection = Mono.deferContextual(connectionCtx -> { + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); + + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header(HttpHeaders.MCP_SESSION_ID, + transportSession.sessionId().get()); + } + + if (stream != null && stream.lastId().isPresent()) { + requestBuilder = requestBuilder.header(HttpHeaders.LAST_EVENT_ID, stream.lastId().get()); + } + + var builder = requestBuilder.uri(uri) + .header(HttpHeaders.ACCEPT, TEXT_EVENT_STREAM) + .header("Cache-Control", "no-cache") + .header(HttpHeaders.PROTOCOL_VERSION, + connectionCtx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) + .GET(); + var transportContext = connectionCtx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null, transportContext)); + }) + .flatMapMany( + requestBuilder -> Flux.create( + sseSink -> this.httpClient + .sendAsync(requestBuilder.build(), + responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, + sseSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + sseSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })) + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) + .flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + + if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + try { + // We don't support batching ATM and probably + // won't since the next version considers + // removing it. + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage( + this.jsonMapper, responseEvent.sseEvent().data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(responseEvent.sseEvent().id()), + List.of(message)); + + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, + this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + + } + catch (IOException ioException) { + return Flux.error(new McpTransportException( + "Error parsing JSON-RPC message: " + responseEvent, ioException)); + } + } + else { + logger.debug("Received SSE event with type: {}", responseEvent.sseEvent()); + return Flux.empty(); + } + } + else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed + logger + .debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (statusCode == NOT_FOUND) { + + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id + // and the response is 404, we consider it a + // session not found error. + logger.debug("Session not found for session ID: {}", + transportSession.sessionId().get()); + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + return Flux.error( + new McpTransportException("Server Not Found. Status code:" + statusCode + + ", response-event:" + responseEvent)); + } + else if (statusCode == BAD_REQUEST) { + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id + // and thre response is 404, we consider it a + // session not found error. + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + return Flux.error( + new McpTransportException("Bad Request. Status code:" + statusCode + + ", response-event:" + responseEvent)); + + } + + return Flux.error(new McpTransportException( + "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + }).flatMap( + jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + })) + .contextWrite(ctx) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + return Mono.just(connection); + }); + + } + + private BodyHandler toSendMessageBodySubscriber(FluxSink sink) { + + BodyHandler responseBodyHandler = responseInfo -> { + + String contentType = responseInfo.headers().firstValue(HttpHeaders.CONTENT_TYPE).orElse("").toLowerCase(); + + if (contentType.contains(TEXT_EVENT_STREAM)) { + // For SSE streams, use line subscriber that returns Void + logger.debug("Received SSE stream response, using line subscriber"); + return ResponseSubscribers.sseToBodySubscriber(responseInfo, sink); + } + else if (contentType.contains(APPLICATION_JSON)) { + // For JSON responses and others, use string subscriber + logger.debug("Received response, using string subscriber"); + return ResponseSubscribers.aggregateBodySubscriber(responseInfo, sink); + } + + logger.debug("Received Bodyless response, using discarding subscriber"); + return ResponseSubscribers.bodilessBodySubscriber(responseInfo, sink); + }; + + return responseBodyHandler; + + } + + public String toString(McpSchema.JSONRPCMessage message) { + try { + return this.jsonMapper.writeValueAsString(message); + } + catch (IOException e) { + throw new RuntimeException("Failed to serialize JSON-RPC message", e); + } + } + + public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { + return Mono.create(deliveredSink -> { + logger.debug("Sending message {}", sentMessage); + + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + var uri = Utils.resolveUri(this.baseUri, this.endpoint); + String jsonBody = this.toString(sentMessage); + + Disposable connection = Mono.deferContextual(ctx -> { + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); + + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header(HttpHeaders.MCP_SESSION_ID, + transportSession.sessionId().get()); + } + + var builder = requestBuilder.uri(uri) + .header(HttpHeaders.ACCEPT, APPLICATION_JSON + ", " + TEXT_EVENT_STREAM) + .header(HttpHeaders.CONTENT_TYPE, APPLICATION_JSON) + .header(HttpHeaders.CACHE_CONTROL, "no-cache") + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) + .POST(HttpRequest.BodyPublishers.ofString(jsonBody)); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono + .from(this.httpRequestCustomizer.customize(builder, "POST", uri, jsonBody, transportContext)); + }).flatMapMany(requestBuilder -> Flux.create(responseEventSink -> { + + // Create the async request with proper body subscriber selection + Mono.fromFuture(this.httpClient + .sendAsync(requestBuilder.build(), this.toSendMessageBodySubscriber(responseEventSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + responseEventSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })).onErrorMap(CompletionException.class, t -> t.getCause()).onErrorComplete().subscribe(); + + })).flatMap(responseEvent -> { + if (transportSession.markInitialized( + responseEvent.responseInfo().headers().firstValue("mcp-session-id").orElseGet(() -> null))) { + // Once we have a session, we try to open an async stream for + // the server to send notifications and requests out-of-band. + + reconnect(null).contextWrite(deliveredSink.contextView()).subscribe(); + } + + String sessionRepresentation = sessionIdOrPlaceholder(transportSession); + + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + + String contentType = responseEvent.responseInfo() + .headers() + .firstValue(HttpHeaders.CONTENT_TYPE) + .orElse("") + .toLowerCase(); + + String contentLength = responseEvent.responseInfo() + .headers() + .firstValue(HttpHeaders.CONTENT_LENGTH) + .orElse(null); + + // For empty content or HTTP code 202 (ACCEPTED), assume success + if (contentType.isBlank() || "0".equals(contentLength) || statusCode == 202) { + // if (contentType.isBlank() || "0".equals(contentLength)) { + logger.debug("No body returned for POST in session {}", sessionRepresentation); + // No content type means no response body, so we can just + // return an empty stream + deliveredSink.success(); + return Flux.empty(); + } + else if (contentType.contains(TEXT_EVENT_STREAM)) { + return Flux.just(((ResponseSubscribers.SseResponseEvent) responseEvent).sseEvent()) + .flatMap(sseEvent -> { + try { + // We don't support batching ATM and probably + // won't + // since the + // next version considers removing it. + McpSchema.JSONRPCMessage message = McpSchema + .deserializeJsonRpcMessage(this.jsonMapper, sseEvent.data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(sseEvent.id()), List.of(message)); + + McpTransportStream sessionStream = new DefaultMcpTransportStream<>( + this.resumableStreams, this::reconnect); + + logger.debug("Connected stream {}", sessionStream.streamId()); + + deliveredSink.success(); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + } + catch (IOException ioException) { + return Flux.error(new McpTransportException( + "Error parsing JSON-RPC message: " + responseEvent, ioException)); + } + }); + } + else if (contentType.contains(APPLICATION_JSON)) { + deliveredSink.success(); + String data = ((ResponseSubscribers.AggregateResponseEvent) responseEvent).data(); + if (sentMessage instanceof McpSchema.JSONRPCNotification) { + logger.warn("Notification: {} received non-compliant response: {}", sentMessage, + Utils.hasText(data) ? data : "[empty]"); + return Mono.empty(); + } + + try { + return Mono.just(McpSchema.deserializeJsonRpcMessage(jsonMapper, data)); + } + catch (IOException e) { + return Mono.error(new McpTransportException( + "Error deserializing JSON-RPC message: " + responseEvent, e)); + } + } + logger.warn("Unknown media type {} returned for POST in session {}", contentType, + sessionRepresentation); + + return Flux.error( + new RuntimeException("Unknown media type returned: " + contentType)); + } + else if (statusCode == NOT_FOUND) { + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id and the + // response is 404, we consider it a session not found error. + logger.debug("Session not found for session ID: {}", transportSession.sessionId().get()); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionRepresentation); + return Flux.error(exception); + } + return Flux.error(new McpTransportException( + "Server Not Found. Status code:" + statusCode + ", response-event:" + responseEvent)); + } + else if (statusCode == BAD_REQUEST) { + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id and the + // response is 404, we consider it a session not found error. + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionRepresentation); + return Flux.error(exception); + } + return Flux.error(new McpTransportException( + "Bad Request. Status code:" + statusCode + ", response-event:" + responseEvent)); + } + + return Flux.error( + new RuntimeException("Failed to send message: " + responseEvent)); + }) + .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + // handle the error first + this.handleException(t); + // inform the caller of sendMessage + deliveredSink.error(t); + return true; + }) + .doFinally(s -> { + logger.debug("SendMessage finally: {}", s); + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(deliveredSink.contextView()) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + }); + } + + private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { + return transportSession.sessionId().orElse("[missing_session_id]"); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); + } + + /** + * Builder for {@link HttpClientStreamableHttpTransport}. + */ + public static class Builder { + + private final String baseUri; + + private McpJsonMapper jsonMapper; + + private HttpClient.Builder clientBuilder = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1); + + private String endpoint = DEFAULT_ENDPOINT; + + private boolean resumableStreams = true; + + private boolean openConnectionOnStartup = false; + + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + + private McpAsyncHttpClientRequestCustomizer httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.NOOP; + + private Duration connectTimeout = Duration.ofSeconds(10); + + private List supportedProtocolVersions = List.of(ProtocolVersions.MCP_2024_11_05, + ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18); + + /** + * Creates a new builder with the specified base URI. + * @param baseUri the base URI of the MCP server + */ + private Builder(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + } + + /** + * Sets the HTTP client builder. + * @param clientBuilder the HTTP client builder + * @return this builder + */ + public Builder clientBuilder(HttpClient.Builder clientBuilder) { + Assert.notNull(clientBuilder, "clientBuilder must not be null"); + this.clientBuilder = clientBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param clientCustomizer the consumer to customize the HTTP client builder + * @return this builder + */ + public Builder customizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + return this; + } + + /** + * Sets the HTTP request builder. + * @param requestBuilder the HTTP request builder + * @return this builder + */ + public Builder requestBuilder(HttpRequest.Builder requestBuilder) { + Assert.notNull(requestBuilder, "requestBuilder must not be null"); + this.requestBuilder = requestBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param requestCustomizer the consumer to customize the HTTP request builder + * @return this builder + */ + public Builder customizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + return this; + } + + /** + * Configure a custom {@link McpJsonMapper} implementation to use. + * @param jsonMapper instance to use + * @return the builder instance + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Configure the endpoint to make HTTP requests against. + * @param endpoint endpoint to use + * @return the builder instance + */ + public Builder endpoint(String endpoint) { + Assert.hasText(endpoint, "endpoint must be a non-empty String"); + this.endpoint = endpoint; + return this; + } + + /** + * Configure whether to use the stream resumability feature by keeping track of + * SSE event ids. + * @param resumableStreams if {@code true} event ids will be tracked and upon + * disconnection, the last seen id will be used upon reconnection as a header to + * resume consuming messages. + * @return the builder instance + */ + public Builder resumableStreams(boolean resumableStreams) { + this.resumableStreams = resumableStreams; + return this; + } + + /** + * Configure whether the client should open an SSE connection upon startup. Not + * all servers support this (although it is in theory possible with the current + * specification), so use with caution. By default, this value is {@code false}. + * @param openConnectionOnStartup if {@code true} the {@link #connect(Function)} + * method call will try to open an SSE connection before sending any JSON-RPC + * request + * @return the builder instance + */ + public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { + this.openConnectionOnStartup = openConnectionOnStartup; + return this; + } + + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

    + * This overrides the customizer from + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)}. + *

    + * Do NOT use a blocking {@link McpSyncHttpClientRequestCustomizer} in a + * non-blocking context. Use + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)} + * instead. + * @param syncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder httpRequestCustomizer(McpSyncHttpClientRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.fromSync(syncHttpRequestCustomizer); + return this; + } + + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

    + * This overrides the customizer from + * {@link #httpRequestCustomizer(McpSyncHttpClientRequestCustomizer)}. + *

    + * Do NOT use a blocking implementation in a non-blocking context. + * @param asyncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer asyncHttpRequestCustomizer) { + this.httpRequestCustomizer = asyncHttpRequestCustomizer; + return this; + } + + /** + * Sets the connection timeout for the HTTP client. + * @param connectTimeout the connection timeout duration + * @return this builder + */ + public Builder connectTimeout(Duration connectTimeout) { + Assert.notNull(connectTimeout, "connectTimeout must not be null"); + this.connectTimeout = connectTimeout; + return this; + } + + /** + * Sets the list of supported protocol versions used in version negotiation. By + * default, the client will send the latest of those versions in the + * {@code MCP-Protocol-Version} header. + *

    + * Setting this value only updates the values used in version negotiation, and + * does NOT impact the actual capabilities of the transport. It should only be + * used for compatibility with servers having strict requirements around the + * {@code MCP-Protocol-Version} header. + * @param supportedProtocolVersions protocol versions supported by this transport + * @return this builder + * @see version + * negotiation specification + * @see Protocol + * Version Header + */ + public Builder supportedProtocolVersions(List supportedProtocolVersions) { + Assert.notEmpty(supportedProtocolVersions, "supportedProtocolVersions must not be empty"); + this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); + return this; + } + + /** + * Construct a fresh instance of {@link HttpClientStreamableHttpTransport} using + * the current builder configuration. + * @return a new instance of {@link HttpClientStreamableHttpTransport} + */ + public HttpClientStreamableHttpTransport build() { + HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); + return new HttpClientStreamableHttpTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + httpClient, requestBuilder, baseUri, endpoint, resumableStreams, openConnectionOnStartup, + httpRequestCustomizer, supportedProtocolVersions); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java new file mode 100644 index 000000000..29dc23c35 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java @@ -0,0 +1,327 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ + +package io.modelcontextprotocol.client.transport; + +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodySubscriber; +import java.net.http.HttpResponse.ResponseInfo; +import java.util.concurrent.atomic.AtomicReference; +import java.util.regex.Pattern; + +import org.reactivestreams.FlowAdapters; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.spec.McpTransportException; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.FluxSink; + +/** + * Utility class providing various {@link BodySubscriber} implementations for handling + * different types of HTTP response bodies in the context of Model Context Protocol (MCP) + * clients. + * + *

    + * Defines subscribers for processing Server-Sent Events (SSE), aggregate responses, and + * bodiless responses. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +class ResponseSubscribers { + + private static final Logger logger = LoggerFactory.getLogger(ResponseSubscribers.class); + + record SseEvent(String id, String event, String data) { + } + + sealed interface ResponseEvent permits SseResponseEvent, AggregateResponseEvent, DummyEvent { + + ResponseInfo responseInfo(); + + } + + record DummyEvent(ResponseInfo responseInfo) implements ResponseEvent { + + } + + record SseResponseEvent(ResponseInfo responseInfo, SseEvent sseEvent) implements ResponseEvent { + } + + record AggregateResponseEvent(ResponseInfo responseInfo, String data) implements ResponseEvent { + } + + static BodySubscriber sseToBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + return HttpResponse.BodySubscribers + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new SseLineSubscriber(responseInfo, sink))); + } + + static BodySubscriber aggregateBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + return HttpResponse.BodySubscribers + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new AggregateSubscriber(responseInfo, sink))); + } + + static BodySubscriber bodilessBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + return HttpResponse.BodySubscribers + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new BodilessResponseLineSubscriber(responseInfo, sink))); + } + + static class SseLineSubscriber extends BaseSubscriber { + + /** + * Pattern to extract data content from SSE "data:" lines. + */ + private static final Pattern EVENT_DATA_PATTERN = Pattern.compile("^data:(.+)$", Pattern.MULTILINE); + + /** + * Pattern to extract event ID from SSE "id:" lines. + */ + private static final Pattern EVENT_ID_PATTERN = Pattern.compile("^id:(.+)$", Pattern.MULTILINE); + + /** + * Pattern to extract event type from SSE "event:" lines. + */ + private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE); + + /** + * The sink for emitting parsed response events. + */ + private final FluxSink sink; + + /** + * StringBuilder for accumulating multi-line event data. + */ + private final StringBuilder eventBuilder; + + /** + * Current event's ID, if specified. + */ + private final AtomicReference currentEventId; + + /** + * Current event's type, if specified. + */ + private final AtomicReference currentEventType; + + /** + * The response information from the HTTP response. Send with each event to + * provide context. + */ + private ResponseInfo responseInfo; + + /** + * Creates a new LineSubscriber that will emit parsed SSE events to the provided + * sink. + * @param sink the {@link FluxSink} to emit parsed {@link ResponseEvent} objects + * to + */ + public SseLineSubscriber(ResponseInfo responseInfo, FluxSink sink) { + this.sink = sink; + this.eventBuilder = new StringBuilder(); + this.currentEventId = new AtomicReference<>(); + this.currentEventType = new AtomicReference<>(); + this.responseInfo = responseInfo; + } + + @Override + protected void hookOnSubscribe(Subscription subscription) { + + sink.onRequest(n -> { + subscription.request(n); + }); + + // Register disposal callback to cancel subscription when Flux is disposed + sink.onDispose(() -> { + subscription.cancel(); + }); + } + + @Override + protected void hookOnNext(String line) { + if (line.isEmpty()) { + // Empty line means end of event + if (this.eventBuilder.length() > 0) { + String eventData = this.eventBuilder.toString(); + SseEvent sseEvent = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); + + this.sink.next(new SseResponseEvent(responseInfo, sseEvent)); + this.eventBuilder.setLength(0); + } + } + else { + if (line.startsWith("data:")) { + var matcher = EVENT_DATA_PATTERN.matcher(line); + if (matcher.find()) { + this.eventBuilder.append(matcher.group(1).trim()).append("\n"); + } + upstream().request(1); + } + else if (line.startsWith("id:")) { + var matcher = EVENT_ID_PATTERN.matcher(line); + if (matcher.find()) { + this.currentEventId.set(matcher.group(1).trim()); + } + upstream().request(1); + } + else if (line.startsWith("event:")) { + var matcher = EVENT_TYPE_PATTERN.matcher(line); + if (matcher.find()) { + this.currentEventType.set(matcher.group(1).trim()); + } + upstream().request(1); + } + else if (line.startsWith(":")) { + // Ignore comment lines starting with ":" + // This is a no-op, just to skip comments + logger.debug("Ignoring comment line: {}", line); + upstream().request(1); + } + else { + // If the response is not successful, emit an error + this.sink.error(new McpTransportException( + "Invalid SSE response. Status code: " + this.responseInfo.statusCode() + " Line: " + line)); + + } + } + } + + @Override + protected void hookOnComplete() { + if (this.eventBuilder.length() > 0) { + String eventData = this.eventBuilder.toString(); + SseEvent sseEvent = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); + this.sink.next(new SseResponseEvent(responseInfo, sseEvent)); + } + this.sink.complete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + } + + static class AggregateSubscriber extends BaseSubscriber { + + /** + * The sink for emitting parsed response events. + */ + private final FluxSink sink; + + /** + * StringBuilder for accumulating multi-line event data. + */ + private final StringBuilder eventBuilder; + + /** + * The response information from the HTTP response. Send with each event to + * provide context. + */ + private ResponseInfo responseInfo; + + volatile boolean hasRequestedDemand = false; + + /** + * Creates a new JsonLineSubscriber that will emit parsed JSON-RPC messages. + * @param sink the {@link FluxSink} to emit parsed {@link ResponseEvent} objects + * to + */ + public AggregateSubscriber(ResponseInfo responseInfo, FluxSink sink) { + this.sink = sink; + this.eventBuilder = new StringBuilder(); + this.responseInfo = responseInfo; + } + + @Override + protected void hookOnSubscribe(Subscription subscription) { + + sink.onRequest(n -> { + if (!hasRequestedDemand) { + subscription.request(Long.MAX_VALUE); + } + hasRequestedDemand = true; + }); + + // Register disposal callback to cancel subscription when Flux is disposed + sink.onDispose(subscription::cancel); + } + + @Override + protected void hookOnNext(String line) { + this.eventBuilder.append(line).append("\n"); + } + + @Override + protected void hookOnComplete() { + + if (hasRequestedDemand) { + String data = this.eventBuilder.toString(); + this.sink.next(new AggregateResponseEvent(responseInfo, data)); + } + + this.sink.complete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + } + + static class BodilessResponseLineSubscriber extends BaseSubscriber { + + /** + * The sink for emitting parsed response events. + */ + private final FluxSink sink; + + private final ResponseInfo responseInfo; + + volatile boolean hasRequestedDemand = false; + + public BodilessResponseLineSubscriber(ResponseInfo responseInfo, FluxSink sink) { + this.sink = sink; + this.responseInfo = responseInfo; + } + + @Override + protected void hookOnSubscribe(Subscription subscription) { + + sink.onRequest(n -> { + if (!hasRequestedDemand) { + subscription.request(Long.MAX_VALUE); + } + hasRequestedDemand = true; + }); + + // Register disposal callback to cancel subscription when Flux is disposed + sink.onDispose(() -> { + subscription.cancel(); + }); + } + + @Override + protected void hookOnComplete() { + if (hasRequestedDemand) { + // emit dummy event to be able to inspect the response info + // this is a shortcut allowing for a more streamlined processing using + // operator composition instead of having to deal with the + // CompletableFuture along the Subscriber for inspecting the result + this.sink.next(new DummyEvent(responseInfo)); + } + this.sink.complete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java similarity index 88% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 614c65125..1b4eaca97 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -11,14 +11,13 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; import java.util.function.Consumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; @@ -38,7 +37,7 @@ * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ -public class StdioClientTransport implements ClientMcpTransport { +public class StdioClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class); @@ -49,7 +48,7 @@ public class StdioClientTransport implements ClientMcpTransport { /** The server process being communicated with */ private Process process; - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; /** Scheduler for handling inbound messages from the server process */ private Scheduler inboundScheduler; @@ -71,29 +70,20 @@ public class StdioClientTransport implements ClientMcpTransport { private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error); /** - * Creates a new StdioClientTransport with the specified parameters and default - * ObjectMapper. + * Creates a new StdioClientTransport with the specified parameters and JsonMapper. * @param params The parameters for configuring the server process + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization */ - public StdioClientTransport(ServerParameters params) { - this(params, new ObjectMapper()); - } - - /** - * Creates a new StdioClientTransport with the specified parameters and ObjectMapper. - * @param params The parameters for configuring the server process - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - */ - public StdioClientTransport(ServerParameters params, ObjectMapper objectMapper) { + public StdioClientTransport(ServerParameters params, McpJsonMapper jsonMapper) { Assert.notNull(params, "The params can not be null"); - Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + Assert.notNull(jsonMapper, "The JsonMapper can not be null"); this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); this.params = params; - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.errorSink = Sinks.many().unicast().onBackpressureBuffer(); @@ -113,6 +103,7 @@ public StdioClientTransport(ServerParameters params, ObjectMapper objectMapper) @Override public Mono connect(Function, Mono> handler) { return Mono.fromRunnable(() -> { + logger.info("MCP server starting."); handleIncomingMessages(handler); handleIncomingErrors(); @@ -143,6 +134,7 @@ public Mono connect(Function, Mono> h startInboundProcessing(); startOutboundProcessing(); startErrorProcessing(); + logger.info("MCP server started"); }).subscribeOn(Schedulers.boundedElastic()); } @@ -258,7 +250,7 @@ private void startInboundProcessing() { String line; while (!isClosing && (line = processReader.readLine()) != null) { try { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, line); + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, line); if (!this.inboundSink.tryEmitNext(message).isSuccess()) { if (!isClosing) { logger.error("Failed to enqueue inbound message: {}", message); @@ -268,7 +260,7 @@ private void startInboundProcessing() { } catch (Exception e) { if (!isClosing) { - logger.error("Error processing inbound message for line: " + line, e); + logger.error("Error processing inbound message for line: {}", line, e); } break; } @@ -293,13 +285,13 @@ private void startInboundProcessing() { */ private void startOutboundProcessing() { this.handleOutbound(messages -> messages - // this bit is important since writes come from user threads and we + // this bit is important since writes come from user threads, and we // want to ensure that the actual writing happens on a dedicated thread .publishOn(outboundScheduler) .handle((message, s) -> { if (message != null && !isClosing) { try { - String jsonMessage = objectMapper.writeValueAsString(message); + String jsonMessage = jsonMapper.writeValueAsString(message); // Escape any embedded newlines in the JSON message as per spec: // https://spec.modelcontextprotocol.io/specification/basic/transports/#stdio // - Messages are delimited by newlines, and MUST NOT contain @@ -345,26 +337,30 @@ public Mono closeGracefully() { return Mono.fromRunnable(() -> { isClosing = true; logger.debug("Initiating graceful shutdown"); - }).then(Mono.defer(() -> { + }).then(Mono.defer(() -> { // First complete all sinks to stop accepting new messages inboundSink.tryEmitComplete(); outboundSink.tryEmitComplete(); errorSink.tryEmitComplete(); // Give a short time for any pending messages to be processed - return Mono.delay(Duration.ofMillis(100)); - })).then(Mono.fromFuture(() -> { + return Mono.delay(Duration.ofMillis(100)).then(); + })).then(Mono.defer(() -> { logger.debug("Sending TERM to process"); if (this.process != null) { this.process.destroy(); - return process.onExit(); + return Mono.fromFuture(process.onExit()); } else { - return CompletableFuture.failedFuture(new RuntimeException("Process not started")); + logger.warn("Process not started"); + return Mono.empty(); } })).doOnNext(process -> { if (process.exitValue() != 0) { - logger.warn("Process terminated with code " + process.exitValue()); + logger.warn("Process terminated with code {}", process.exitValue()); + } + else { + logger.info("MCP server process stopped"); } }).then(Mono.fromRunnable(() -> { try { @@ -387,8 +383,8 @@ public Sinks.Many getErrorSink() { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..2492efe18 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizer.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; + +import org.reactivestreams.Publisher; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.util.Assert; + +import reactor.core.publisher.Mono; + +/** + * Composable {@link McpAsyncHttpClientRequestCustomizer} that applies multiple + * customizers, in order. + * + * @author Daniel Garnier-Moiroux + */ +public class DelegatingMcpAsyncHttpClientRequestCustomizer implements McpAsyncHttpClientRequestCustomizer { + + private final List customizers; + + public DelegatingMcpAsyncHttpClientRequestCustomizer(List customizers) { + Assert.notNull(customizers, "Customizers must not be null"); + this.customizers = customizers; + } + + @Override + public Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + String body, McpTransportContext context) { + var result = Mono.just(builder); + for (var customizer : this.customizers) { + result = result.flatMap(b -> Mono.from(customizer.customize(b, method, endpoint, body, context))); + } + return result; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..e627e7e69 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizer.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.util.Assert; + +/** + * Composable {@link McpSyncHttpClientRequestCustomizer} that applies multiple + * customizers, in order. + * + * @author Daniel Garnier-Moiroux + */ +public class DelegatingMcpSyncHttpClientRequestCustomizer implements McpSyncHttpClientRequestCustomizer { + + private final List delegates; + + public DelegatingMcpSyncHttpClientRequestCustomizer(List customizers) { + Assert.notNull(customizers, "Customizers must not be null"); + this.delegates = customizers; + } + + @Override + public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body, + McpTransportContext context) { + this.delegates.forEach(delegate -> delegate.customize(builder, method, endpoint, body, context)); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpAsyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpAsyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..756b39c35 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpAsyncHttpClientRequestCustomizer.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.util.annotation.Nullable; + +import io.modelcontextprotocol.common.McpTransportContext; + +/** + * Customize {@link HttpRequest.Builder} before executing the request, in either SSE or + * Streamable HTTP transport. + *

    + * When used in a non-blocking context, implementations MUST be non-blocking. + * + * @author Daniel Garnier-Moiroux + */ +public interface McpAsyncHttpClientRequestCustomizer { + + Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + @Nullable String body, McpTransportContext context); + + McpAsyncHttpClientRequestCustomizer NOOP = new Noop(); + + /** + * Wrap a sync implementation in an async wrapper. + *

    + * Do NOT wrap a blocking implementation for use in a non-blocking context. For a + * blocking implementation, consider using {@link Schedulers#boundedElastic()}. + */ + static McpAsyncHttpClientRequestCustomizer fromSync(McpSyncHttpClientRequestCustomizer customizer) { + return (builder, method, uri, body, context) -> Mono.fromSupplier(() -> { + customizer.customize(builder, method, uri, body, context); + return builder; + }); + } + + class Noop implements McpAsyncHttpClientRequestCustomizer { + + @Override + public Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + String body, McpTransportContext context) { + return Mono.just(builder); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpSyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpSyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..e22e3aa62 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpSyncHttpClientRequestCustomizer.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; + +import reactor.util.annotation.Nullable; + +import io.modelcontextprotocol.client.McpClient.SyncSpec; +import io.modelcontextprotocol.common.McpTransportContext; + +/** + * Customize {@link HttpRequest.Builder} before executing the request, either in SSE or + * Streamable HTTP transport. Do not rely on thread-locals in this implementation, instead + * use {@link SyncSpec#transportContextProvider} to extract context, and then consume it + * through {@link McpTransportContext}. + * + * @author Daniel Garnier-Moiroux + */ +public interface McpSyncHttpClientRequestCustomizer { + + void customize(HttpRequest.Builder builder, String method, URI endpoint, @Nullable String body, + McpTransportContext context); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/common/DefaultMcpTransportContext.java b/mcp-core/src/main/java/io/modelcontextprotocol/common/DefaultMcpTransportContext.java new file mode 100644 index 000000000..cde637b15 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/common/DefaultMcpTransportContext.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; + +import io.modelcontextprotocol.util.Assert; + +/** + * Default implementation for {@link McpTransportContext} which uses a map as storage. + * + * @author Dariusz Jędrzejczyk + * @author Daniel Garnier-Moiroux + */ +class DefaultMcpTransportContext implements McpTransportContext { + + private final Map metadata; + + DefaultMcpTransportContext(Map metadata) { + Assert.notNull(metadata, "The metadata cannot be null"); + this.metadata = metadata; + } + + @Override + public Object get(String key) { + return this.metadata.get(key); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) + return false; + + DefaultMcpTransportContext that = (DefaultMcpTransportContext) o; + return this.metadata.equals(that.metadata); + } + + @Override + public int hashCode() { + return this.metadata.hashCode(); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java b/mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java new file mode 100644 index 000000000..46a2ccf84 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Collections; +import java.util.Map; + +/** + * Context associated with the transport layer. It allows to add transport-level metadata + * for use further down the line. Specifically, it can be beneficial to extract HTTP + * request metadata for use in MCP feature implementations. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpTransportContext { + + /** + * Key for use in Reactor Context to transport the context to user land. + */ + String KEY = "MCP_TRANSPORT_CONTEXT"; + + /** + * An empty, unmodifiable context. + */ + @SuppressWarnings("unchecked") + McpTransportContext EMPTY = new DefaultMcpTransportContext(Collections.EMPTY_MAP); + + /** + * Create an unmodifiable context containing the given metadata. + * @param metadata the transport metadata + * @return the context containing the metadata + */ + static McpTransportContext create(Map metadata) { + return new DefaultMcpTransportContext(metadata); + } + + /** + * Extract a value from the context. + * @param key the key under the data is expected + * @return the associated value or {@code null} if missing. + */ + Object get(String key); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java new file mode 100644 index 000000000..d1b55f594 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import java.util.Map; + +class DefaultMcpStatelessServerHandler implements McpStatelessServerHandler { + + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpStatelessServerHandler.class); + + Map> requestHandlers; + + Map notificationHandlers; + + public DefaultMcpStatelessServerHandler(Map> requestHandlers, + Map notificationHandlers) { + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public Mono handleRequest(McpTransportContext transportContext, + McpSchema.JSONRPCRequest request) { + McpStatelessRequestHandler requestHandler = this.requestHandlers.get(request.method()); + if (requestHandler == null) { + return Mono.error(new McpError("Missing handler for request type: " + request.method())); + } + return requestHandler.handle(transportContext, request.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(t -> { + McpSchema.JSONRPCResponse.JSONRPCError error; + if (t instanceof McpError mcpError && mcpError.getJsonRpcError() != null) { + error = mcpError.getJsonRpcError(); + } + else { + error = new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + t.getMessage(), null); + } + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); + }); + } + + @Override + public Mono handleNotification(McpTransportContext transportContext, + McpSchema.JSONRPCNotification notification) { + McpStatelessNotificationHandler notificationHandler = this.notificationHandlers.get(notification.method()); + if (notificationHandler == null) { + logger.warn("Missing handler for notification type: {}", notification.method()); + return Mono.empty(); + } + return notificationHandler.handle(transportContext, notification.params()); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java new file mode 100644 index 000000000..23285d514 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -0,0 +1,1076 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.spec.DefaultMcpStreamableServerSessionFactory; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; +import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProviderBase; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import static io.modelcontextprotocol.spec.McpError.RESOURCE_NOT_FOUND; + +/** + * The Model Context Protocol (MCP) server implementation that provides asynchronous + * communication using Project Reactor's Mono and Flux types. + * + *

    + * This server implements the MCP specification, enabling AI models to expose tools, + * resources, and prompts through a standardized interface. Key features include: + *

      + *
    • Asynchronous communication using reactive programming patterns + *
    • Dynamic tool registration and management + *
    • Resource handling with URI-based addressing + *
    • Prompt template management + *
    • Real-time client notifications for state changes + *
    • Structured logging with configurable severity levels + *
    • Support for client-side AI model sampling + *
    + * + *

    + * The server follows a lifecycle: + *

      + *
    1. Initialization - Accepts client connections and negotiates capabilities + *
    2. Normal Operation - Handles client requests and sends notifications + *
    3. Graceful Shutdown - Ensures clean connection termination + *
    + * + *

    + * This implementation uses Project Reactor for non-blocking operations, making it + * suitable for high-throughput scenarios and reactive applications. All operations return + * Mono or Flux types that can be composed into reactive pipelines. + * + *

    + * The server supports runtime modification of its capabilities through methods like + * {@link #addTool}, {@link #addResource}, and {@link #addPrompt}, automatically notifying + * connected clients of changes when configured to do so. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Jihoon Kim + * @see McpServer + * @see McpSchema + * @see McpClientSession + */ +public class McpAsyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); + + private final McpServerTransportProviderBase mcpTransportProvider; + + private final McpJsonMapper jsonMapper; + + private final JsonSchemaValidator jsonSchemaValidator; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final String instructions; + + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap resourceTemplates = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + // FIXME: this field is deprecated and should be remvoed together with the + // broadcasting loggingNotification. + private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + + private List protocolVersions; + + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); + + /** + * Create a new McpAsyncServer with the given transport provider and capabilities. + * @param mcpTransportProvider The transport layer implementation for MCP + * communication. + * @param features The MCP server supported features. + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization + */ + McpAsyncServer(McpServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + this.mcpTransportProvider = mcpTransportProvider; + this.jsonMapper = jsonMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities().mutate().logging().build(); + this.instructions = features.instructions(); + this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); + this.resources.putAll(features.resources()); + this.resourceTemplates.putAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.jsonSchemaValidator = jsonSchemaValidator; + + Map> requestHandlers = prepareRequestHandlers(); + Map notificationHandlers = prepareNotificationHandlers(features); + + this.protocolVersions = mcpTransportProvider.protocolVersions(); + + mcpTransportProvider.setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), + requestTimeout, transport, this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); + } + + McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + this.mcpTransportProvider = mcpTransportProvider; + this.jsonMapper = jsonMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities().mutate().logging().build(); + this.instructions = features.instructions(); + this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); + this.resources.putAll(features.resources()); + this.resourceTemplates.putAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.jsonSchemaValidator = jsonSchemaValidator; + + Map> requestHandlers = prepareRequestHandlers(); + Map notificationHandlers = prepareNotificationHandlers(features); + + this.protocolVersions = mcpTransportProvider.protocolVersions(); + + mcpTransportProvider.setSessionFactory(new DefaultMcpStreamableServerSessionFactory(requestTimeout, + this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); + } + + private Map prepareNotificationHandlers(McpServerFeatures.Async features) { + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features + .rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger + .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + return notificationHandlers; + } + + private Map> prepareRequestHandlers() { + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add logging API handlers if the logging capability is enabled + if (this.serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + } + + // Add completion API handlers if the completion capability is enabled + if (this.serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + } + return requestHandlers; + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private Mono asyncInitializeRequestHandler( + McpSchema.InitializeRequest initializeRequest) { + return Mono.defer(() -> { + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, this.instructions)); + }); + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.mcpTransportProvider.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.mcpTransportProvider.close(); + } + + private McpNotificationHandler asyncRootsListChangedNotificationHandler( + List, Mono>> rootsChangeConsumers) { + return (exchange, params) -> exchange.listRoots() + .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> Mono.defer(() -> consumer.apply(exchange, listRootsResult.roots()))) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + /** + * Add a new tool call specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + if (toolSpecification == null) { + return Mono.error(new IllegalArgumentException("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new IllegalArgumentException("Tool must not be null")); + } + if (toolSpecification.call() == null && toolSpecification.callHandler() == null) { + return Mono.error(new IllegalArgumentException("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); + } + + var wrappedToolSpecification = withStructuredOutputHandling(this.jsonSchemaValidator, toolSpecification); + + return Mono.defer(() -> { + // Remove tools with duplicate tool names first + if (this.tools.removeIf(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { + logger.warn("Replace existing Tool with name '{}'", wrappedToolSpecification.tool().name()); + } + + this.tools.add(wrappedToolSpecification); + logger.debug("Added tool handler: {}", wrappedToolSpecification.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); + } + + private static class StructuredOutputCallToolHandler + implements BiFunction> { + + private final BiFunction> delegateCallToolResult; + + private final JsonSchemaValidator jsonSchemaValidator; + + private final Map outputSchema; + + public StructuredOutputCallToolHandler(JsonSchemaValidator jsonSchemaValidator, + Map outputSchema, + BiFunction> delegateHandler) { + + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + Assert.notNull(delegateHandler, "Delegate call tool result handler must not be null"); + + this.delegateCallToolResult = delegateHandler; + this.outputSchema = outputSchema; + this.jsonSchemaValidator = jsonSchemaValidator; + } + + @Override + public Mono apply(McpAsyncServerExchange exchange, McpSchema.CallToolRequest request) { + + return this.delegateCallToolResult.apply(exchange, request).map(result -> { + + if (Boolean.TRUE.equals(result.isError())) { + // If the tool call resulted in an error, skip further validation + return result; + } + + if (outputSchema == null) { + if (result.structuredContent() != null) { + logger.warn( + "Tool call with no outputSchema is not expected to have a result with structured content, but got: {}", + result.structuredContent()); + } + // Pass through. No validation is required if no output schema is + // provided. + return result; + } + + // If an output schema is provided, servers MUST provide structured + // results that conform to this schema. + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema + if (result.structuredContent() == null) { + String content = "Response missing structured content which is expected when calling tool with non-empty outputSchema"; + logger.warn(content); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(content))) + .isError(true) + .build(); + } + + // Validate the result against the output schema + var validation = this.jsonSchemaValidator.validate(outputSchema, result.structuredContent()); + + if (!validation.valid()) { + logger.warn("Tool call result validation failed: {}", validation.errorMessage()); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.errorMessage()))) + .isError(true) + .build(); + } + + if (Utils.isEmpty(result.content())) { + // For backwards compatibility, a tool that returns structured + // content SHOULD also return functionally equivalent unstructured + // content. (For example, serialized JSON can be returned in a + // TextContent block.) + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content + + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput()))) + .isError(result.isError()) + .structuredContent(result.structuredContent()) + .build(); + } + + return result; + }); + } + + } + + private static List withStructuredOutputHandling( + JsonSchemaValidator jsonSchemaValidator, List tools) { + + if (Utils.isEmpty(tools)) { + return tools; + } + + return tools.stream().map(tool -> withStructuredOutputHandling(jsonSchemaValidator, tool)).toList(); + } + + private static McpServerFeatures.AsyncToolSpecification withStructuredOutputHandling( + JsonSchemaValidator jsonSchemaValidator, McpServerFeatures.AsyncToolSpecification toolSpecification) { + + if (toolSpecification.callHandler() instanceof StructuredOutputCallToolHandler) { + // If the tool is already wrapped, return it as is + return toolSpecification; + } + + if (toolSpecification.tool().outputSchema() == null) { + // If the tool does not have an output schema, return it as is + return toolSpecification; + } + + return McpServerFeatures.AsyncToolSpecification.builder() + .tool(toolSpecification.tool()) + .callHandler(new StructuredOutputCallToolHandler(jsonSchemaValidator, + toolSpecification.tool().outputSchema(), toolSpecification.callHandler())) + .build(); + } + + /** + * List all registered tools. + * @return A Flux stream of all registered tools + */ + public Flux listTools() { + return Flux.fromIterable(this.tools).map(McpServerFeatures.AsyncToolSpecification::tool); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new IllegalArgumentException("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + if (this.tools.removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName))) { + + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + } + else { + logger.warn("Ignore as a Tool with name '{}' not found", toolName); + } + + return Mono.empty(); + }); + } + + /** + * Notifies clients that the list of available tools has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyToolsListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + private McpRequestHandler toolsListRequestHandler() { + return (exchange, params) -> { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpRequestHandler toolsCallRequestHandler() { + return (exchange, params) -> { + McpSchema.CallToolRequest callToolRequest = jsonMapper.convertValue(params, + new TypeRef() { + }); + + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isEmpty()) { + return Mono.error(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Unknown tool: invalid_tool_name") + .data("Tool not found: " + callToolRequest.name()) + .build()); + } + + return toolSpecification.get().callHandler().apply(exchange, callToolRequest); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { + return Mono.error(new IllegalArgumentException("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow adding resources")); + } + + return Mono.defer(() -> { + var previous = this.resources.put(resourceSpecification.resource().uri(), resourceSpecification); + if (previous != null) { + logger.warn("Replace existing Resource with URI '{}'", resourceSpecification.resource().uri()); + } + else { + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); + } + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + /** + * List all registered resources. + * @return A Flux stream of all registered resources + */ + public Flux listResources() { + return Flux.fromIterable(this.resources.values()).map(McpServerFeatures.AsyncResourceSpecification::resource); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new IllegalArgumentException("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow removing resources")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + } + else { + logger.warn("Ignore as a Resource with URI '{}' not found", resourceUri); + } + return Mono.empty(); + }); + } + + /** + * Add a new resource template at runtime. + * @param resourceTemplateSpecification The resource template to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResourceTemplate( + McpServerFeatures.AsyncResourceTemplateSpecification resourceTemplateSpecification) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow adding resource templates")); + } + + return Mono.defer(() -> { + var previous = this.resourceTemplates.put(resourceTemplateSpecification.resourceTemplate().uriTemplate(), + resourceTemplateSpecification); + if (previous != null) { + logger.warn("Replace existing Resource Template with URI '{}'", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + else { + logger.debug("Added resource template handler: {}", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + /** + * List all registered resource templates. + * @return A Flux stream of all registered resource templates + */ + public Flux listResourceTemplates() { + return Flux.fromIterable(this.resourceTemplates.values()) + .map(McpServerFeatures.AsyncResourceTemplateSpecification::resourceTemplate); + } + + /** + * Remove a resource template at runtime. + * @param uriTemplate The URI template of the resource template to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResourceTemplate(String uriTemplate) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow removing resource templates")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceTemplateSpecification removed = this.resourceTemplates.remove(uriTemplate); + if (removed != null) { + logger.debug("Removed resource template: {}", uriTemplate); + } + else { + logger.warn("Ignore as a Resource Template with URI '{}' not found", uriTemplate); + } + return Mono.empty(); + }); + } + + /** + * Notifies clients that the list of available resources has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyResourcesListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + /** + * Notifies clients that the resources have updated. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification) { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_UPDATED, + resourcesUpdatedNotification); + } + + private McpRequestHandler resourcesListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpRequestHandler resourceTemplateListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resourceTemplates.values() + .stream() + .map(McpServerFeatures.AsyncResourceTemplateSpecification::resourceTemplate) + .toList(); + return Mono.just(new McpSchema.ListResourceTemplatesResult(resourceList, null)); + }; + } + + private McpRequestHandler resourcesReadRequestHandler() { + return (ex, params) -> { + McpSchema.ReadResourceRequest resourceRequest = jsonMapper.convertValue(params, new TypeRef<>() { + }); + + var resourceUri = resourceRequest.uri(); + + // First try to find a static resource specification + // Static resources have exact URIs + return this.findResourceSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ex, resourceRequest)) + .orElseGet(() -> { + // If not found, try to find a dynamic resource specification + // Dynamic resources have URI templates + return this.findResourceTemplateSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ex, resourceRequest)) + .orElseGet(() -> Mono.error(RESOURCE_NOT_FOUND.apply(resourceUri))); + }); + }; + } + + private Optional findResourceSpecification(String uri) { + var result = this.resources.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resource().uri()).matches(uri)) + .findFirst(); + return result; + } + + private Optional findResourceTemplateSpecification( + String uri) { + return this.resourceTemplates.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resourceTemplate().uriTemplate()).matches(uri)) + .findFirst(); + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + if (promptSpecification == null) { + return Mono.error(new IllegalArgumentException("Prompt specification must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + var previous = this.prompts.put(promptSpecification.prompt().name(), promptSpecification); + if (previous != null) { + logger.warn("Replace existing Prompt with name '{}'", promptSpecification.prompt().name()); + } + else { + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); + } + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + + return Mono.empty(); + }); + } + + /** + * List all registered prompts. + * @return A Flux stream of all registered prompts + */ + public Flux listPrompts() { + return Flux.fromIterable(this.prompts.values()).map(McpServerFeatures.AsyncPromptSpecification::prompt); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new IllegalArgumentException("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + return Mono.empty(); + } + else { + logger.warn("Ignore as a Prompt with name '{}' not found", promptName); + } + return Mono.empty(); + }); + } + + /** + * Notifies clients that the list of available prompts has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyPromptsListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + private McpRequestHandler promptsListRequestHandler() { + return (exchange, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpRequestHandler promptsGetRequestHandler() { + return (exchange, params) -> { + McpSchema.GetPromptRequest promptRequest = jsonMapper.convertValue(params, + new TypeRef() { + }); + + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + + if (specification == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Invalid prompt name") + .data("Prompt not found: " + promptRequest.name()) + .build()); + } + + return Mono.defer(() -> specification.promptHandler().apply(exchange, promptRequest)); + }; + } + + // --------------------------------------- + // Logging Management + // --------------------------------------- + + /** + * This implementation would, incorrectly, broadcast the logging message to all + * connected clients, using a single minLoggingLevel for all of them. Similar to the + * sampling and roots, the logging level should be set per client session and use the + * ServerExchange to send the logging message to the right client. + * @param loggingMessageNotification The logging message to send + * @return A Mono that completes when the notification has been sent + * @deprecated Use + * {@link McpAsyncServerExchange#loggingNotification(LoggingMessageNotification)} + * instead. + */ + @Deprecated + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); + } + + if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { + return Mono.empty(); + } + + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, + loggingMessageNotification); + } + + private McpRequestHandler setLoggerRequestHandler() { + return (exchange, params) -> { + return Mono.defer(() -> { + + SetLevelRequest newMinLoggingLevel = jsonMapper.convertValue(params, new TypeRef() { + }); + + exchange.setMinLoggingLevel(newMinLoggingLevel.level()); + + // FIXME: this field is deprecated and should be removed together + // with the broadcasting loggingNotification. + this.minLoggingLevel = newMinLoggingLevel.level(); + + return Mono.just(Map.of()); + }); + }; + } + + private static final Mono EMPTY_COMPLETION_RESULT = Mono + .just(new McpSchema.CompleteResult(new CompleteCompletion(List.of(), 0, false))); + + private McpRequestHandler completionCompleteRequestHandler() { + return (exchange, params) -> { + + McpSchema.CompleteRequest request = parseCompletionParams(params); + + if (request.ref() == null) { + return Mono.error( + McpError.builder(ErrorCodes.INVALID_PARAMS).message("Completion ref must not be null").build()); + } + + if (request.ref().type() == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Completion ref type must not be null") + .build()); + } + + String type = request.ref().type(); + + String argumentName = request.argument().name(); + + // Check if valid a Prompt exists for this completion request + if (type.equals(PromptReference.TYPE) + && request.ref() instanceof McpSchema.PromptReference promptReference) { + + McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); + if (promptSpec == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Prompt not found: " + promptReference.name()) + .build()); + } + if (!promptSpec.prompt() + .arguments() + .stream() + .filter(arg -> arg.name().equals(argumentName)) + .findFirst() + .isPresent()) { + + logger.warn("Argument not found: {} in prompt: {}", argumentName, promptReference.name()); + + return EMPTY_COMPLETION_RESULT; + } + } + + // Check if valid Resource or ResourceTemplate exists for this completion + // request + if (type.equals(ResourceReference.TYPE) + && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + + var uriTemplateManager = uriTemplateManagerFactory.create(resourceReference.uri()); + + if (!uriTemplateManager.isUriTemplate(resourceReference.uri())) { + // Attempting to autocomplete a fixed resource URI is not an error in + // the spec (but probably should be). + return EMPTY_COMPLETION_RESULT; + } + + McpServerFeatures.AsyncResourceSpecification resourceSpec = this + .findResourceSpecification(resourceReference.uri()) + .orElse(null); + + if (resourceSpec != null) { + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource: " + resourceReference.uri()) + .build()); + } + } + else { + var templateSpec = this.findResourceTemplateSpecification(resourceReference.uri()).orElse(null); + if (templateSpec != null) { + + if (!uriTemplateManagerFactory.create(templateSpec.resourceTemplate().uriTemplate()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource template: " + + resourceReference.uri()) + .build()); + } + } + else { + return Mono.error(RESOURCE_NOT_FOUND.apply(resourceReference.uri())); + } + } + } + + // Handle the completion request using the registered handler + // for the given reference. + McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); + + if (specification == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("AsyncCompletionSpecification not found: " + request.ref()) + .build()); + } + + return Mono.defer(() -> specification.completionHandler().apply(exchange, request)); + }; + } + + /** + * Parses the raw JSON-RPC request parameters into a {@link McpSchema.CompleteRequest} + * object. + *

    + * This method manually extracts the `ref` and `argument` fields from the input map, + * determines the correct reference type (either prompt or resource), and constructs a + * fully-typed {@code CompleteRequest} instance. + * @param object the raw request parameters, expected to be a Map containing "ref" and + * "argument" entries. + * @return a {@link McpSchema.CompleteRequest} representing the structured completion + * request. + * @throws IllegalArgumentException if the "ref" type is not recognized. + */ + @SuppressWarnings("unchecked") + private McpSchema.CompleteRequest parseCompletionParams(Object object) { + Map params = (Map) object; + Map refMap = (Map) params.get("ref"); + Map argMap = (Map) params.get("argument"); + Map contextMap = (Map) params.get("context"); + Map meta = (Map) params.get("_meta"); + + String refType = (String) refMap.get("type"); + + McpSchema.CompleteReference ref = switch (refType) { + case PromptReference.TYPE -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), + refMap.get("title") != null ? (String) refMap.get("title") : null); + case ResourceReference.TYPE -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + default -> throw new IllegalArgumentException("Invalid ref type: " + refType); + }; + + String argName = (String) argMap.get("name"); + String argValue = (String) argMap.get("value"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName, + argValue); + + McpSchema.CompleteRequest.CompleteContext context = null; + if (contextMap != null) { + Map arguments = (Map) contextMap.get("arguments"); + context = new McpSchema.CompleteRequest.CompleteContext(arguments); + } + + return new McpSchema.CompleteRequest(ref, argument, meta, context); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java new file mode 100644 index 000000000..a15c58cd5 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -0,0 +1,260 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import java.util.ArrayList; +import java.util.Collections; + +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpLoggableSession; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSession; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; + +/** + * Represents an asynchronous exchange with a Model Context Protocol (MCP) client. The + * exchange provides methods to interact with the client and query its capabilities. + * + * @author Dariusz Jędrzejczyk + * @author Christian Tzolov + */ +public class McpAsyncServerExchange { + + private final String sessionId; + + private final McpLoggableSession session; + + private final McpSchema.ClientCapabilities clientCapabilities; + + private final McpSchema.Implementation clientInfo; + + private final McpTransportContext transportContext; + + private static final TypeRef CREATE_MESSAGE_RESULT_TYPE_REF = new TypeRef<>() { + }; + + private static final TypeRef LIST_ROOTS_RESULT_TYPE_REF = new TypeRef<>() { + }; + + private static final TypeRef ELICITATION_RESULT_TYPE_REF = new TypeRef<>() { + }; + + public static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { + }; + + /** + * Create a new asynchronous exchange with the client. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param clientInfo The client implementation information. + * @deprecated Use + * {@link #McpAsyncServerExchange(String, McpLoggableSession, McpSchema.ClientCapabilities, McpSchema.Implementation, McpTransportContext)} + */ + @Deprecated + public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo) { + this.sessionId = null; + if (!(session instanceof McpLoggableSession)) { + throw new IllegalArgumentException("Expecting session to be a McpLoggableSession instance"); + } + this.session = (McpLoggableSession) session; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + this.transportContext = McpTransportContext.EMPTY; + } + + /** + * Create a new asynchronous exchange with the client. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param clientInfo The client implementation information. + * @param transportContext context associated with the client as extracted from the + * transport + */ + public McpAsyncServerExchange(String sessionId, McpLoggableSession session, + McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + McpTransportContext transportContext) { + this.sessionId = sessionId; + this.session = session; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + this.transportContext = transportContext; + } + + /** + * Get the client capabilities that define the supported features and functionality. + * @return The client capabilities + */ + public McpSchema.ClientCapabilities getClientCapabilities() { + return this.clientCapabilities; + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + public McpSchema.Implementation getClientInfo() { + return this.clientInfo; + } + + /** + * Provides the {@link McpTransportContext} associated with the transport layer. For + * HTTP transports it can contain the metadata associated with the HTTP request that + * triggered the processing. + * @return the transport context object + */ + public McpTransportContext transportContext() { + return this.transportContext; + } + + /** + * Provides the Session ID. + * @return session ID string + */ + public String sessionId() { + return this.sessionId; + } + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. This + * flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server API + * keys necessary. Servers can request text or image-based interactions and optionally + * include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A Mono that completes when the message has been created + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + */ + public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.sampling() == null) { + return Mono.error(new McpError("Client must be configured with sampling capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, + CREATE_MESSAGE_RESULT_TYPE_REF); + } + + /** + * Creates a new elicitation. MCP provides a standardized way for servers to request + * additional information from users through the client during interactions. This flow + * allows clients to maintain control over user interactions and data sharing while + * enabling servers to gather necessary information dynamically. Servers can request + * structured data from users with optional JSON schemas to validate responses. + * @param elicitRequest The request to create a new elicitation + * @return A Mono that completes when the elicitation has been resolved. + * @see McpSchema.ElicitRequest + * @see McpSchema.ElicitResult + * @see Elicitation + * Specification + */ + public Mono createElicitation(McpSchema.ElicitRequest elicitRequest) { + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.elicitation() == null) { + return Mono.error(new McpError("Client must be configured with elicitation capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, + ELICITATION_RESULT_TYPE_REF); + } + + /** + * Retrieves the list of all roots provided by the client. + * @return A Mono that emits the list of roots result. + */ + public Mono listRoots() { + + // @formatter:off + return this.listRoots(McpSchema.FIRST_PAGE) + .expand(result -> (result.nextCursor() != null) ? + this.listRoots(result.nextCursor()) : Mono.empty()) + .reduce(new McpSchema.ListRootsResult(new ArrayList<>(), null), + (allRootsResult, result) -> { + allRootsResult.roots().addAll(result.roots()); + return allRootsResult; + }) + .map(result -> new McpSchema.ListRootsResult(Collections.unmodifiableList(result.roots()), + result.nextCursor())); + // @formatter:on + } + + /** + * Retrieves a paginated list of roots provided by the client. + * @param cursor Optional pagination cursor from a previous list request + * @return A Mono that emits the list of roots result containing + */ + public Mono listRoots(String cursor) { + return this.session.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_ROOTS_RESULT_TYPE_REF); + } + + /** + * Send a logging message notification to the client. Messages below the current + * minimum logging level will be filtered out. + * @param loggingMessageNotification The logging message to send + * @return A Mono that completes when the notification has been sent + */ + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); + } + + return Mono.defer(() -> { + if (this.session.isNotificationForLevelAllowed(loggingMessageNotification.level())) { + return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification); + } + return Mono.empty(); + }); + } + + /** + * Sends a notification to the client that the current progress status has changed for + * long-running operations. + * @param progressNotification The progress notification to send + * @return A Mono that completes when the notification has been sent + */ + public Mono progressNotification(McpSchema.ProgressNotification progressNotification) { + if (progressNotification == null) { + return Mono.error(new McpError("Progress notification must not be null")); + } + return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS, progressNotification); + } + + /** + * Sends a ping request to the client. + * @return A Mono that completes with clients's ping response + */ + public Mono ping() { + return this.session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF); + } + + /** + * Set the minimum logging level for the client. Messages below this level will be + * filtered out. + * @param minLoggingLevel The minimum logging level + */ + void setMinLoggingLevel(LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.session.setMinLoggingLevel(minLoggingLevel); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java new file mode 100644 index 000000000..13ff45a54 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java @@ -0,0 +1,22 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; + +/** + * Request handler for the initialization request. + */ +public interface McpInitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java new file mode 100644 index 000000000..6b1061c03 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java @@ -0,0 +1,19 @@ +package io.modelcontextprotocol.server; + +import reactor.core.publisher.Mono; + +/** + * A handler for client-initiated notifications. + */ +public interface McpNotificationHandler { + + /** + * Handles a notification from the client. + * @param exchange the exchange associated with the client that allows calling back to + * the connected client or inspecting its capabilities. + * @param params the parameters of the notification. + * @return a Mono that completes once the notification is handled. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java new file mode 100644 index 000000000..c9d70ad04 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java @@ -0,0 +1,22 @@ +package io.modelcontextprotocol.server; + +import reactor.core.publisher.Mono; + +/** + * A handler for client-initiated requests. + * + * @param the type of the response that is expected as a result of handling the + * request. + */ +public interface McpRequestHandler { + + /** + * Handles a request from the client. + * @param exchange the exchange associated with the client that allows calling back to + * the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return a Mono that will emit the response to the request. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java new file mode 100644 index 000000000..fe3125271 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -0,0 +1,2362 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.json.McpJsonMapper; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import reactor.core.publisher.Mono; + +/** + * Factory class for creating Model Context Protocol (MCP) servers. MCP servers expose + * tools, resources, and prompts to AI models through a standardized interface. + * + *

    + * This class serves as the main entry point for implementing the server-side of the MCP + * specification. The server's responsibilities include: + *

      + *
    • Exposing tools that models can invoke to perform actions + *
    • Providing access to resources that give models context + *
    • Managing prompt templates for structured model interactions + *
    • Handling client connections and requests + *
    • Implementing capability negotiation + *
    + * + *

    + * Thread Safety: Both synchronous and asynchronous server implementations are + * thread-safe. The synchronous server processes requests sequentially, while the + * asynchronous server can handle concurrent requests safely through its reactive + * programming model. + * + *

    + * Error Handling: The server implementations provide robust error handling through the + * McpError class. Errors are properly propagated to clients while maintaining the + * server's stability. Server implementations should use appropriate error codes and + * provide meaningful error messages to help diagnose issues. + * + *

    + * The class provides factory methods to create either: + *

      + *
    • {@link McpAsyncServer} for non-blocking operations with reactive responses + *
    • {@link McpSyncServer} for blocking operations with direct responses + *
    + * + *

    + * Example of creating a basic synchronous server:

    {@code
    + * McpServer.sync(transportProvider)
    + *     .serverInfo("my-server", "1.0.0")
    + *     .tool(Tool.builder().name("calculator").title("Performs calculations").inputSchema(schema).build(),
    + *           (exchange, args) -> CallToolResult.builder()
    + *                   .content(List.of(new McpSchema.TextContent("Result: " + calculate(args))))
    + *                   .isError(false)
    + *                   .build())
    + *     .build();
    + * }
    + * + * Example of creating a basic asynchronous server:
    {@code
    + * McpServer.async(transportProvider)
    + *     .serverInfo("my-server", "1.0.0")
    + *     .tool(Tool.builder().name("calculator").title("Performs calculations").inputSchema(schema).build(),
    + *           (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
    + *               .map(result -> CallToolResult.builder()
    + *                   .content(List.of(new McpSchema.TextContent("Result: " + result)))
    + *                   .isError(false)
    + *                   .build()))
    + *     .build();
    + * }
    + * + *

    + * Example with comprehensive asynchronous configuration:

    {@code
    + * McpServer.async(transportProvider)
    + *     .serverInfo("advanced-server", "2.0.0")
    + *     .capabilities(new ServerCapabilities(...))
    + *     // Register tools
    + *     .tools(
    + *         McpServerFeatures.AsyncToolSpecification.builder()
    + * 			.tool(calculatorTool)
    + *   	    .callTool((exchange, args) -> Mono.fromSupplier(() -> calculate(args.arguments()))
    + *                 .map(result -> CallToolResult.builder()
    + *                   .content(List.of(new McpSchema.TextContent("Result: " + result)))
    + *                   .isError(false)
    + *                   .build()))
    + *.         .build(),
    + *         McpServerFeatures.AsyncToolSpecification.builder()
    + * 	        .tool((weatherTool)
    + *          .callTool((exchange, args) -> Mono.fromSupplier(() -> getWeather(args.arguments()))
    + *                 .map(result -> CallToolResult.builder()
    + *                   .content(List.of(new McpSchema.TextContent("Weather: " + result)))
    + *                   .isError(false)
    + *                   .build()))
    + *          .build()
    + *     )
    + *     // Register resources
    + *     .resources(
    + *         new McpServerFeatures.AsyncResourceSpecification(fileResource,
    + *             (exchange, req) -> Mono.fromSupplier(() -> readFile(req))
    + *                 .map(ReadResourceResult::new)),
    + *         new McpServerFeatures.AsyncResourceSpecification(dbResource,
    + *             (exchange, req) -> Mono.fromSupplier(() -> queryDb(req))
    + *                 .map(ReadResourceResult::new))
    + *     )
    + *     // Add resource templates
    + *     .resourceTemplates(
    + *         new ResourceTemplate("file://{path}", "Access files"),
    + *         new ResourceTemplate("db://{table}", "Access database")
    + *     )
    + *     // Register prompts
    + *     .prompts(
    + *         new McpServerFeatures.AsyncPromptSpecification(analysisPrompt,
    + *             (exchange, req) -> Mono.fromSupplier(() -> generateAnalysisPrompt(req))
    + *                 .map(GetPromptResult::new)),
    + *         new McpServerFeatures.AsyncPromptRegistration(summaryPrompt,
    + *             (exchange, req) -> Mono.fromSupplier(() -> generateSummaryPrompt(req))
    + *                 .map(GetPromptResult::new))
    + *     )
    + *     .build();
    + * }
    + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Jihoon Kim + * @see McpAsyncServer + * @see McpSyncServer + * @see McpServerTransportProvider + */ +public interface McpServer { + + McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("Java SDK MCP Server", "0.15.0"); + + /** + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. + */ + static SingleSessionSyncSpecification sync(McpServerTransportProvider transportProvider) { + return new SingleSessionSyncSpecification(transportProvider); + } + + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. + */ + static AsyncSpecification async(McpServerTransportProvider transportProvider) { + return new SingleSessionAsyncSpecification(transportProvider); + } + + /** + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. + */ + static StreamableSyncSpecification sync(McpStreamableServerTransportProvider transportProvider) { + return new StreamableSyncSpecification(transportProvider); + } + + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. + */ + static AsyncSpecification async(McpStreamableServerTransportProvider transportProvider) { + return new StreamableServerAsyncSpecification(transportProvider); + } + + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transport The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. + */ + static StatelessAsyncSpecification async(McpStatelessServerTransport transport) { + return new StatelessAsyncSpecification(transport); + } + + /** + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transport The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. + */ + static StatelessSyncSpecification sync(McpStatelessServerTransport transport) { + return new StatelessSyncSpecification(transport); + } + + class SingleSessionAsyncSpecification extends AsyncSpecification { + + private final McpServerTransportProvider transportProvider; + + private SingleSessionAsyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Builds an asynchronous MCP server that provides non-blocking operations. + * @return A new instance of {@link McpAsyncServer} configured with this builder's + * settings. + */ + @Override + public McpAsyncServer build() { + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, + this.instructions); + + var jsonSchemaValidator = (this.jsonSchemaValidator != null) ? this.jsonSchemaValidator + : JsonSchemaValidator.getDefault(); + + return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); + } + + } + + class StreamableServerAsyncSpecification extends AsyncSpecification { + + private final McpStreamableServerTransportProvider transportProvider; + + public StreamableServerAsyncSpecification(McpStreamableServerTransportProvider transportProvider) { + this.transportProvider = transportProvider; + } + + /** + * Builds an asynchronous MCP server that provides non-blocking operations. + * @return A new instance of {@link McpAsyncServer} configured with this builder's + * settings. + */ + @Override + public McpAsyncServer build() { + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, + this.instructions); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : JsonSchemaValidator.getDefault(); + return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); + } + + } + + /** + * Asynchronous server specification. + */ + abstract class AsyncSpecification> { + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); + + McpJsonMapper jsonMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + JsonSchemaValidator jsonSchemaValidator; + + String instructions; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + final List, Mono>> rootsChangeHandlers = new ArrayList<>(); + + Duration requestTimeout = Duration.ofHours(10); // Default timeout + + public abstract McpAsyncServer build(); + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public AsyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public AsyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public AsyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
      + *
    • Tool execution + *
    • Resource access + *
    • Prompt handling + *
    + * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.AsyncToolSpecification} explicitly. + * + *

    + * Example usage:

    {@code
    +		 * .tool(
    +		 *     Tool.builder().name("calculator").title("Performs calculations").inputSchema(schema).build(),
    +		 *     (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
    +		 *         .map(result -> CallToolResult.builder()
    +		 *                   .content(List.of(new McpSchema.TextContent("Result: " + result)))
    +		 *                   .isError(false)
    +		 *                   .build()))
    +		 * )
    +		 * }
    + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param handler The function that implements the tool's logic. Must not be null. + * The function's first argument is an {@link McpAsyncServerExchange} upon which + * the server can interact with the connected client. The second argument is the + * map of arguments passed to the tool. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + * @deprecated Use {@link #toolCall(McpSchema.Tool, BiFunction)} instead for tool + * calls that require a request object. + */ + @Deprecated + public AsyncSpecification tool(McpSchema.Tool tool, + BiFunction, Mono> handler) { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(handler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpServerFeatures.AsyncToolSpecification(tool, handler)); + + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.AsyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param callHandler The function that implements the tool's logic. Must not be + * null. The function's first argument is an {@link McpAsyncServerExchange} upon + * which the server can interact with the connected client. The second argument is + * the {@link McpSchema.CallToolRequest} object containing the tool call + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public AsyncSpecification toolCall(McpSchema.Tool tool, + BiFunction> callHandler) { + + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools + .add(McpServerFeatures.AsyncToolSpecification.builder().tool(tool).callHandler(callHandler).build()); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpServerFeatures.AsyncToolSpecification...) + */ + public AsyncSpecification tools(List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

    + * Example usage:

    {@code
    +		 * .tools(
    +		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(calculatorTool).callTool(calculatorHandler).build(),
    +		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(weatherTool).callTool(weatherHandler).build(),
    +		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(fileManagerTool).callTool(fileManagerHandler).build()
    +		 * )
    +		 * }
    + * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + */ + public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (McpServerFeatures.AsyncToolSpecification tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.AsyncResourceSpecification...) + */ + public AsyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.AsyncResourceSpecification...) + */ + public AsyncSpecification resources( + List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

    + * Example usage:

    {@code
    +		 * .resources(
    +		 *     new McpServerFeatures.AsyncResourceSpecification(fileResource, fileHandler),
    +		 *     new McpServerFeatures.AsyncResourceSpecification(dbResource, dbHandler),
    +		 *     new McpServerFeatures.AsyncResourceSpecification(apiResource, apiHandler)
    +		 * )
    +		 * }
    + * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public AsyncSpecification resources(McpServerFeatures.AsyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resource templates with their specifications using a List. + * This method is useful when resource templates need to be added in bulk from a + * collection. + * @param resourceTemplates Map of template URI to specification. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + */ + public AsyncSpecification resourceTemplates( + List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (var resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } + return this; + } + + /** + * Registers multiple resource templates with their specifications using a List. + * This method is useful when resource templates need to be added in bulk from a + * collection. + * @param resourceTemplates List of template URI to specification. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public AsyncSpecification resourceTemplates( + McpServerFeatures.AsyncResourceTemplateSpecification... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (McpServerFeatures.AsyncResourceTemplateSpecification resource : resourceTemplates) { + this.resourceTemplates.put(resource.resourceTemplate().uriTemplate(), resource); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptSpecification(
    +		 *     new Prompt("analysis", "Code analysis template"),
    +		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
    +		 *         .map(GetPromptResult::new)
    +		 * )));
    +		 * }
    + * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public AsyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpServerFeatures.AsyncPromptSpecification...) + */ + public AsyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(
    +		 *     new McpServerFeatures.AsyncPromptSpecification(analysisPrompt, analysisHandler),
    +		 *     new McpServerFeatures.AsyncPromptSpecification(summaryPrompt, summaryHandler),
    +		 *     new McpServerFeatures.AsyncPromptSpecification(reviewPrompt, reviewHandler)
    +		 * )
    +		 * }
    + * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public AsyncSpecification prompts(McpServerFeatures.AsyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public AsyncSpecification completions(List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public AsyncSpecification completions(McpServerFeatures.AsyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers a consumer that will be notified when the list of roots changes. This + * is useful for updating resource availability dynamically, such as when new + * files are added or removed. + * @param handler The handler to register. Must not be null. The function's first + * argument is an {@link McpAsyncServerExchange} upon which the server can + * interact with the connected client. The second argument is the list of roots. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumer is null + */ + public AsyncSpecification rootsChangeHandler( + BiFunction, Mono> handler) { + Assert.notNull(handler, "Consumer must not be null"); + this.rootsChangeHandlers.add(handler); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes. This method is useful when multiple consumers need to be registered at + * once. + * @param handlers The list of handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandler(BiFunction) + */ + public AsyncSpecification rootsChangeHandlers( + List, Mono>> handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + this.rootsChangeHandlers.addAll(handlers); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes using varargs. This method provides a convenient way to register + * multiple consumers inline. + * @param handlers The handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandlers(List) + */ + public AsyncSpecification rootsChangeHandlers( + @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + return this.rootsChangeHandlers(Arrays.asList(handlers)); + } + + /** + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if jsonMapper is null + */ + public AsyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool and resource schemas. + * This ensures that the server's tools and resources conform to the expected + * schema definitions. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public AsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + } + + class SingleSessionSyncSpecification extends SyncSpecification { + + private final McpServerTransportProvider transportProvider; + + private SingleSessionSyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's + * settings. + */ + @Override + public McpSyncServer build() { + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, + this.immediateExecution); + + var asyncServer = new McpAsyncServer(transportProvider, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, asyncFeatures, requestTimeout, + uriTemplateManagerFactory, + jsonSchemaValidator != null ? jsonSchemaValidator : JsonSchemaValidator.getDefault()); + return new McpSyncServer(asyncServer, this.immediateExecution); + } + + } + + class StreamableSyncSpecification extends SyncSpecification { + + private final McpStreamableServerTransportProvider transportProvider; + + private StreamableSyncSpecification(McpStreamableServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's + * settings. + */ + @Override + public McpSyncServer build() { + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, + this.immediateExecution); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : JsonSchemaValidator.getDefault(); + var asyncServer = new McpAsyncServer(transportProvider, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, asyncFeatures, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + return new McpSyncServer(asyncServer, this.immediateExecution); + } + + } + + /** + * Synchronous server specification. + */ + abstract class SyncSpecification> { + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); + + McpJsonMapper jsonMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + String instructions; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); + + JsonSchemaValidator jsonSchemaValidator; + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + final List>> rootsChangeHandlers = new ArrayList<>(); + + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + boolean immediateExecution = false; + + public abstract McpSyncServer build(); + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return this builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public SyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public SyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public SyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
      + *
    • Tool execution + *
    • Resource access + *
    • Prompt handling + *
    + * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.SyncToolSpecification} explicitly. + * + *

    + * Example usage:

    {@code
    +		 * .tool(
    +		 *     Tool.builder().name("calculator").title("Performs calculations".inputSchema(schema).build(),
    +		 *     (exchange, args) -> CallToolResult.builder()
    +		 *                   .content(List.of(new McpSchema.TextContent("Result: " + calculate(args))))
    +		 *                   .isError(false)
    +		 *                   .build())
    +		 * )
    +		 * }
    + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param handler The function that implements the tool's logic. Must not be null. + * The function's first argument is an {@link McpSyncServerExchange} upon which + * the server can interact with the connected client. The second argument is the + * list of arguments passed to the tool. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + * @deprecated Use {@link #toolCall(McpSchema.Tool, BiFunction)} instead for tool + * calls that require a request object. + */ + @Deprecated + public SyncSpecification tool(McpSchema.Tool tool, + BiFunction, McpSchema.CallToolResult> handler) { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(handler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, handler)); + + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.SyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param handler The function that implements the tool's logic. Must not be null. + * The function's first argument is an {@link McpSyncServerExchange} upon which + * the server can interact with the connected client. The second argument is the + * list of arguments passed to the tool. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public SyncSpecification toolCall(McpSchema.Tool tool, + BiFunction handler) { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(handler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, null, handler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpServerFeatures.SyncToolSpecification...) + */ + public SyncSpecification tools(List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + String toolName = tool.tool().name(); + assertNoDuplicateTool(toolName); // Check against existing tools + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

    + * Example usage:

    {@code
    +		 * .tools(
    +		 *     new ToolSpecification(calculatorTool, calculatorHandler),
    +		 *     new ToolSpecification(weatherTool, weatherHandler),
    +		 *     new ToolSpecification(fileManagerTool, fileManagerHandler)
    +		 * )
    +		 * }
    + * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(List) + */ + public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (McpServerFeatures.SyncToolSpecification tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.SyncResourceSpecification...) + */ + public SyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.SyncResourceSpecification...) + */ + public SyncSpecification resources( + List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

    + * Example usage:

    {@code
    +		 * .resources(
    +		 *     new ResourceSpecification(fileResource, fileHandler),
    +		 *     new ResourceSpecification(dbResource, dbHandler),
    +		 *     new ResourceSpecification(apiResource, apiHandler)
    +		 * )
    +		 * }
    + * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public SyncSpecification resources(McpServerFeatures.SyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * @param resourceTemplates List of resource template specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + */ + public SyncSpecification resourceTemplates( + List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (McpServerFeatures.SyncResourceTemplateSpecification resource : resourceTemplates) { + this.resourceTemplates.put(resource.resourceTemplate().uriTemplate(), resource); + } + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null + * @see #resourceTemplates(List) + */ + public SyncSpecification resourceTemplates( + McpServerFeatures.SyncResourceTemplateSpecification... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (McpServerFeatures.SyncResourceTemplateSpecification resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

    + * Example usage:

    {@code
    +		 * Map prompts = new HashMap<>();
    +		 * prompts.put("analysis", new PromptSpecification(
    +		 *     new Prompt("analysis", "Code analysis template"),
    +		 *     (exchange, request) -> new GetPromptResult(generateAnalysisPrompt(request))
    +		 * ));
    +		 * .prompts(prompts)
    +		 * }
    + * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public SyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpServerFeatures.SyncPromptSpecification...) + */ + public SyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(
    +		 *     new PromptSpecification(analysisPrompt, analysisHandler),
    +		 *     new PromptSpecification(summaryPrompt, summaryHandler),
    +		 *     new PromptSpecification(reviewPrompt, reviewHandler)
    +		 * )
    +		 * }
    + * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + * @see #completions(McpServerFeatures.SyncCompletionSpecification...) + */ + public SyncSpecification completions(List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.SyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public SyncSpecification completions(McpServerFeatures.SyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.SyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers a consumer that will be notified when the list of roots changes. This + * is useful for updating resource availability dynamically, such as when new + * files are added or removed. + * @param handler The handler to register. Must not be null. The function's first + * argument is an {@link McpSyncServerExchange} upon which the server can interact + * with the connected client. The second argument is the list of roots. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumer is null + */ + public SyncSpecification rootsChangeHandler( + BiConsumer> handler) { + Assert.notNull(handler, "Consumer must not be null"); + this.rootsChangeHandlers.add(handler); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes. This method is useful when multiple consumers need to be registered at + * once. + * @param handlers The list of handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandler(BiConsumer) + */ + public SyncSpecification rootsChangeHandlers( + List>> handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + this.rootsChangeHandlers.addAll(handlers); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes using varargs. This method provides a convenient way to register + * multiple consumers inline. + * @param handlers The handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandlers(List) + */ + public SyncSpecification rootsChangeHandlers( + BiConsumer>... handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + return this.rootsChangeHandlers(List.of(handlers)); + } + + /** + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if jsonMapper is null + */ + public SyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + public SyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enable on "immediate execution" of the operations on the underlying + * {@link McpAsyncServer}. Defaults to false, which does blocking code offloading + * to prevent accidental blocking of the non-blocking transport. + *

    + * Do NOT set to true if the underlying transport is a non-blocking + * implementation. + * @param immediateExecution When true, do not offload work asynchronously. + * @return This builder instance for method chaining. + * + */ + public SyncSpecification immediateExecution(boolean immediateExecution) { + this.immediateExecution = immediateExecution; + return this; + } + + } + + class StatelessAsyncSpecification { + + private final McpStatelessServerTransport transport; + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); + + McpJsonMapper jsonMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + JsonSchemaValidator jsonSchemaValidator; + + String instructions; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + public StatelessAsyncSpecification(McpStatelessServerTransport transport) { + this.transport = transport; + } + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public StatelessAsyncSpecification uriTemplateManagerFactory( + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public StatelessAsyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public StatelessAsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public StatelessAsyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public StatelessAsyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *

      + *
    • Tool execution + *
    • Resource access + *
    • Prompt handling + *
    + * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public StatelessAsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.AsyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param callHandler The function that implements the tool's logic. Must not be + * null. The function's first argument is an {@link McpAsyncServerExchange} upon + * which the server can interact with the connected client. The second argument is + * the {@link McpSchema.CallToolRequest} object containing the tool call + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public StatelessAsyncSpecification toolCall(McpSchema.Tool tool, + BiFunction> callHandler) { + + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpStatelessServerFeatures.AsyncToolSpecification(tool, callHandler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpStatelessServerFeatures.AsyncToolSpecification...) + */ + public StatelessAsyncSpecification tools( + List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

    + * Example usage:

    {@code
    +		 * .tools(
    +		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(calculatorTool).callTool(calculatorHandler).build(),
    +		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(weatherTool).callTool(weatherHandler).build(),
    +		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(fileManagerTool).callTool(fileManagerHandler).build()
    +		 * )
    +		 * }
    + * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + */ + public StatelessAsyncSpecification tools( + McpStatelessServerFeatures.AsyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.AsyncResourceSpecification...) + */ + public StatelessAsyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.AsyncResourceSpecification...) + */ + public StatelessAsyncSpecification resources( + List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

    + * Example usage:

    {@code
    +		 * .resources(
    +		 *     new McpServerFeatures.AsyncResourceSpecification(fileResource, fileHandler),
    +		 *     new McpServerFeatures.AsyncResourceSpecification(dbResource, dbHandler),
    +		 *     new McpServerFeatures.AsyncResourceSpecification(apiResource, apiHandler)
    +		 * )
    +		 * }
    + * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public StatelessAsyncSpecification resources( + McpStatelessServerFeatures.AsyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + */ + public StatelessAsyncSpecification resourceTemplates( + List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (var resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public StatelessAsyncSpecification resourceTemplates( + McpStatelessServerFeatures.AsyncResourceTemplateSpecification... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (McpStatelessServerFeatures.AsyncResourceTemplateSpecification resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptSpecification(
    +		 *     new Prompt("analysis", "Code analysis template"),
    +		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
    +		 *         .map(GetPromptResult::new)
    +		 * )));
    +		 * }
    + * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessAsyncSpecification prompts( + Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpStatelessServerFeatures.AsyncPromptSpecification...) + */ + public StatelessAsyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(
    +		 *     new McpServerFeatures.AsyncPromptSpecification(analysisPrompt, analysisHandler),
    +		 *     new McpServerFeatures.AsyncPromptSpecification(summaryPrompt, summaryHandler),
    +		 *     new McpServerFeatures.AsyncPromptSpecification(reviewPrompt, reviewHandler)
    +		 * )
    +		 * }
    + * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessAsyncSpecification prompts(McpStatelessServerFeatures.AsyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessAsyncSpecification completions( + List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessAsyncSpecification completions( + McpStatelessServerFeatures.AsyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if jsonMapper is null + */ + public StatelessAsyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool and resource schemas. + * This ensures that the server's tools and resources conform to the expected + * schema definitions. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public StatelessAsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + public McpStatelessAsyncServer build() { + var features = new McpStatelessServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); + return new McpStatelessAsyncServer(transport, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + features, requestTimeout, uriTemplateManagerFactory, + jsonSchemaValidator != null ? jsonSchemaValidator : JsonSchemaValidator.getDefault()); + } + + } + + class StatelessSyncSpecification { + + private final McpStatelessServerTransport transport; + + boolean immediateExecution = false; + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); + + McpJsonMapper jsonMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + JsonSchemaValidator jsonSchemaValidator; + + String instructions; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + public StatelessSyncSpecification(McpStatelessServerTransport transport) { + this.transport = transport; + } + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public StatelessSyncSpecification uriTemplateManagerFactory( + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public StatelessSyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public StatelessSyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public StatelessSyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public StatelessSyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
      + *
    • Tool execution + *
    • Resource access + *
    • Prompt handling + *
    + * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public StatelessSyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.SyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param callHandler The function that implements the tool's logic. Must not be + * null. The function's first argument is an {@link McpSyncServerExchange} upon + * which the server can interact with the connected client. The second argument is + * the {@link McpSchema.CallToolRequest} object containing the tool call + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public StatelessSyncSpecification toolCall(McpSchema.Tool tool, + BiFunction callHandler) { + + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpStatelessServerFeatures.SyncToolSpecification(tool, callHandler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpStatelessServerFeatures.SyncToolSpecification...) + */ + public StatelessSyncSpecification tools( + List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

    + * Example usage:

    {@code
    +		 * .tools(
    +		 *     McpServerFeatures.SyncToolSpecification.builder().tool(calculatorTool).callTool(calculatorHandler).build(),
    +		 *     McpServerFeatures.SyncToolSpecification.builder().tool(weatherTool).callTool(weatherHandler).build(),
    +		 *     McpServerFeatures.SyncToolSpecification.builder().tool(fileManagerTool).callTool(fileManagerHandler).build()
    +		 * )
    +		 * }
    + * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + */ + public StatelessSyncSpecification tools( + McpStatelessServerFeatures.SyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.SyncResourceSpecification...) + */ + public StatelessSyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.SyncResourceSpecification...) + */ + public StatelessSyncSpecification resources( + List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

    + * Example usage:

    {@code
    +		 * .resources(
    +		 *     new McpServerFeatures.SyncResourceSpecification(fileResource, fileHandler),
    +		 *     new McpServerFeatures.SyncResourceSpecification(dbResource, dbHandler),
    +		 *     new McpServerFeatures.SyncResourceSpecification(apiResource, apiHandler)
    +		 * )
    +		 * }
    + * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public StatelessSyncSpecification resources( + McpStatelessServerFeatures.SyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * @param resourceTemplatesSpec List of resource templates. If null, clears + * existing templates. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + */ + public StatelessSyncSpecification resourceTemplates( + List resourceTemplatesSpec) { + Assert.notNull(resourceTemplatesSpec, "Resource templates must not be null"); + for (var resourceTemplate : resourceTemplatesSpec) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public StatelessSyncSpecification resourceTemplates( + McpStatelessServerFeatures.SyncResourceTemplateSpecification... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (McpStatelessServerFeatures.SyncResourceTemplateSpecification resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(Map.of("analysis", new McpServerFeatures.SyncPromptSpecification(
    +		 *     new Prompt("analysis", "Code analysis template"),
    +		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
    +		 *         .map(GetPromptResult::new)
    +		 * )));
    +		 * }
    + * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessSyncSpecification prompts( + Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpStatelessServerFeatures.SyncPromptSpecification...) + */ + public StatelessSyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(
    +		 *     new McpServerFeatures.SyncPromptSpecification(analysisPrompt, analysisHandler),
    +		 *     new McpServerFeatures.SyncPromptSpecification(summaryPrompt, summaryHandler),
    +		 *     new McpServerFeatures.SyncPromptSpecification(reviewPrompt, reviewHandler)
    +		 * )
    +		 * }
    + * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessSyncSpecification prompts(McpStatelessServerFeatures.SyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessSyncSpecification completions( + List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessSyncSpecification completions( + McpStatelessServerFeatures.SyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if jsonMapper is null + */ + public StatelessSyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool and resource schemas. + * This ensures that the server's tools and resources conform to the expected + * schema definitions. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public StatelessSyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enable on "immediate execution" of the operations on the underlying + * {@link McpStatelessAsyncServer}. Defaults to false, which does blocking code + * offloading to prevent accidental blocking of the non-blocking transport. + *

    + * Do NOT set to true if the underlying transport is a non-blocking + * implementation. + * @param immediateExecution When true, do not offload work asynchronously. + * @return This builder instance for method chaining. + * + */ + public StatelessSyncSpecification immediateExecution(boolean immediateExecution) { + this.immediateExecution = immediateExecution; + return this; + } + + public McpStatelessSyncServer build() { + var syncFeatures = new McpStatelessServerFeatures.Sync(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); + var asyncFeatures = McpStatelessServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); + var asyncServer = new McpStatelessAsyncServer(transport, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, asyncFeatures, requestTimeout, + uriTemplateManagerFactory, + this.jsonSchemaValidator != null ? this.jsonSchemaValidator : JsonSchemaValidator.getDefault()); + return new McpStatelessSyncServer(asyncServer, this.immediateExecution); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java new file mode 100644 index 000000000..fe0608b1c --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -0,0 +1,715 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +/** + * MCP server features specification that a particular server can choose to support. + * + * @author Dariusz Jędrzejczyk + * @author Jihoon Kim + */ +public class McpServerFeatures { + + /** + * Asynchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param rootsChangeConsumers The list of consumers that will be notified when the + * roots list changes + * @param instructions The server instructions text + */ + record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + List, Mono>> rootsChangeConsumers, + String instructions) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The map of resource templates + * @param prompts The map of prompt specifications + * @param rootsChangeConsumers The list of consumers that will be notified when + * the roots list changes + * @param instructions The server instructions text + */ + Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + List, Mono>> rootsChangeConsumers, + String instructions) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental + new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable + // logging + // by + // default + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : List.of(); + this.resources = (resources != null) ? resources : Map.of(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); + this.prompts = (prompts != null) ? prompts : Map.of(); + this.completions = (completions != null) ? completions : Map.of(); + this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); + this.instructions = instructions; + } + + /** + * Convert a synchronous specification into an asynchronous one and provide + * blocking code offloading to prevent accidental blocking of the non-blocking + * transport. + * @param syncSpec a potentially blocking, synchronous specification. + * @param immediateExecution when true, do not offload. Do NOT set to true when + * using a non-blocking transport. + * @return a specification which is protected from blocking calls specified by the + * user. + */ + static Async fromSync(Sync syncSpec, boolean immediateExecution) { + List tools = new ArrayList<>(); + for (var tool : syncSpec.tools()) { + tools.add(AsyncToolSpecification.fromSync(tool, immediateExecution)); + } + + Map resources = new HashMap<>(); + syncSpec.resources().forEach((key, resource) -> { + resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); + }); + + Map resourceTemplates = new HashMap<>(); + syncSpec.resourceTemplates().forEach((key, resource) -> { + resourceTemplates.put(key, AsyncResourceTemplateSpecification.fromSync(resource, immediateExecution)); + }); + + Map prompts = new HashMap<>(); + syncSpec.prompts().forEach((key, prompt) -> { + prompts.put(key, AsyncPromptSpecification.fromSync(prompt, immediateExecution)); + }); + + Map completions = new HashMap<>(); + syncSpec.completions().forEach((key, completion) -> { + completions.put(key, AsyncCompletionSpecification.fromSync(completion, immediateExecution)); + }); + + List, Mono>> rootChangeConsumers = new ArrayList<>(); + + for (var rootChangeConsumer : syncSpec.rootsChangeConsumers()) { + rootChangeConsumers.add((exchange, list) -> Mono + .fromRunnable(() -> rootChangeConsumer.accept(new McpSyncServerExchange(exchange), list)) + .subscribeOn(Schedulers.boundedElastic())); + } + + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, resourceTemplates, + prompts, completions, rootChangeConsumers, syncSpec.instructions()); + } + } + + /** + * Synchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param rootsChangeConsumers The list of consumers that will be notified when the + * roots list changes + * @param instructions The server instructions text + */ + record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + List>> rootsChangeConsumers, String instructions) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param rootsChangeConsumers The list of consumers that will be notified when + * the roots list changes + * @param instructions The server instructions text + */ + Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + List>> rootsChangeConsumers, + String instructions) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental + new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable + // logging + // by + // default + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : new ArrayList<>(); + this.resources = (resources != null) ? resources : new HashMap<>(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); + this.prompts = (prompts != null) ? prompts : new HashMap<>(); + this.completions = (completions != null) ? completions : new HashMap<>(); + this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : new ArrayList<>(); + this.instructions = instructions; + } + + } + + /** + * Specification of a tool with its asynchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. Each tool + * represents a specific capability. + * + * @param tool The tool definition including name, description, and parameter schema + * @param call Deprecated. Use the {@link AsyncToolSpecification#callHandler} instead. + * @param callHandler The function that implements the tool's logic, receiving a + * {@link McpAsyncServerExchange} and a + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} and returning + * results. The function's first argument is an {@link McpAsyncServerExchange} upon + * which the server can interact with the connected client. The second arguments is a + * map of tool arguments. + */ + public record AsyncToolSpecification(McpSchema.Tool tool, + @Deprecated BiFunction, Mono> call, + BiFunction> callHandler) { + + /** + * @deprecated Use {@link AsyncToolSpecification(McpSchema.Tool, null, + * BiFunction)} instead. + **/ + @Deprecated + public AsyncToolSpecification(McpSchema.Tool tool, + BiFunction, Mono> call) { + this(tool, call, (exchange, toolReq) -> call.apply(exchange, toolReq.arguments())); + } + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec) { + return fromSync(syncToolSpec, false); + } + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec, boolean immediate) { + + // FIXME: This is temporary, proper validation should be implemented + if (syncToolSpec == null) { + return null; + } + + BiFunction, Mono> deprecatedCall = (syncToolSpec + .call() != null) ? (exchange, map) -> { + var toolResult = Mono + .fromCallable(() -> syncToolSpec.call().apply(new McpSyncServerExchange(exchange), map)); + return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); + } : null; + + BiFunction> callHandler = ( + exchange, req) -> { + var toolResult = Mono + .fromCallable(() -> syncToolSpec.callHandler().apply(new McpSyncServerExchange(exchange), req)); + return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); + }; + + return new AsyncToolSpecification(syncToolSpec.tool(), deprecatedCall, callHandler); + } + + /** + * Builder for creating AsyncToolSpecification instances. + */ + public static class Builder { + + private McpSchema.Tool tool; + + private BiFunction> callHandler; + + /** + * Sets the tool definition. + * @param tool The tool definition including name, description, and parameter + * schema + * @return this builder instance + */ + public Builder tool(McpSchema.Tool tool) { + this.tool = tool; + return this; + } + + /** + * Sets the call tool handler function. + * @param callHandler The function that implements the tool's logic + * @return this builder instance + */ + public Builder callHandler( + BiFunction> callHandler) { + this.callHandler = callHandler; + return this; + } + + /** + * Builds the AsyncToolSpecification instance. + * @return a new AsyncToolSpecification instance + * @throws IllegalArgumentException if required fields are not set + */ + public AsyncToolSpecification build() { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Call handler function must not be null"); + + return new AsyncToolSpecification(tool, null, callHandler); + } + + } + + /** + * Creates a new builder instance. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + } + + /** + * Specification of a resource with its asynchronous handler function. Resources + * provide context to AI models by exposing data such as: + *

      + *
    • File contents + *
    • Database records + *
    • API responses + *
    • System information + *
    • Application state + *
    + * + *

    + * Example resource specification: + * + *

    {@code
    +	 * new McpServerFeatures.AsyncResourceSpecification(
    +	 *     Resource.builder()
    +	 *         .uri("docs")
    +	 *         .name("Documentation files")
    +	 * 		   .title("Documentation files")
    +	 * 		   .mimeType("text/markdown")
    +	 * 		   .description("Markdown documentation files")
    +	 * 		   .build(),
    +	 * 		(exchange, request) -> Mono.fromSupplier(() -> readFile(request.getPath()))
    +	 * 				.map(ReadResourceResult::new))
    +	 * }
    + * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpAsyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest}. + */ + public record AsyncResourceSpecification(McpSchema.Resource resource, + BiFunction> readHandler) { + + static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceSpecification(resource.resource(), (exchange, req) -> { + var resourceResult = Mono + .fromCallable(() -> resource.readHandler().apply(new McpSyncServerExchange(exchange), req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
      + *
    • Parameterized resource definitions + *
    • Dynamic content generation + *
    • Consistent resource formatting + *
    • Contextual data injection + *
    + * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpSyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record AsyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction> readHandler) { + + static AsyncResourceTemplateSpecification fromSync(SyncResourceTemplateSpecification resource, + boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceTemplateSpecification(resource.resourceTemplate(), (exchange, req) -> { + var resourceResult = Mono + .fromCallable(() -> resource.readHandler().apply(new McpSyncServerExchange(exchange), req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a prompt template with its asynchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
      + *
    • Consistent message formatting + *
    • Parameter substitution + *
    • Context injection + *
    • Response formatting + *
    • Instruction templating + *
    + * + *

    + * Example prompt specification: + * + *

    {@code
    +	 * new McpServerFeatures.AsyncPromptSpecification(
    +	 * 		new Prompt("analyze", "Code analysis template"),
    +	 * 		(exchange, request) -> {
    +	 * 			String code = request.getArguments().get("code");
    +	 * 			return Mono.just(new GetPromptResult(
    +	 * 					"Analyze this code:\n\n" + code + "\n\nProvide feedback on:"));
    +	 * 		})
    +	 * }
    + * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's first argument is an + * {@link McpAsyncServerExchange} upon which the server can interact with the + * connected client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.GetPromptRequest}. + */ + public record AsyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction> promptHandler) { + + static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (prompt == null) { + return null; + } + return new AsyncPromptSpecification(prompt.prompt(), (exchange, req) -> { + var promptResult = Mono + .fromCallable(() -> prompt.promptHandler().apply(new McpSyncServerExchange(exchange), req)); + return immediateExecution ? promptResult : promptResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a completion handler function with asynchronous execution support. + * Completions generate AI model outputs based on prompt or resource references and + * user-provided arguments. This abstraction enables: + *
      + *
    • Customizable response generation logic + *
    • Parameter-driven template expansion + *
    • Dynamic interaction with connected clients + *
    + * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The asynchronous function that processes completion + * requests and returns results. The first argument is an + * {@link McpAsyncServerExchange} used to interact with the client. The second + * argument is a {@link io.modelcontextprotocol.spec.McpSchema.CompleteRequest}. + */ + public record AsyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction> completionHandler) { + + /** + * Converts a synchronous {@link SyncCompletionSpecification} into an + * {@link AsyncCompletionSpecification} by wrapping the handler in a bounded + * elastic scheduler for safe non-blocking execution. + * @param completion the synchronous completion specification + * @return an asynchronous wrapper of the provided sync specification, or + * {@code null} if input is null + */ + static AsyncCompletionSpecification fromSync(SyncCompletionSpecification completion, + boolean immediateExecution) { + if (completion == null) { + return null; + } + return new AsyncCompletionSpecification(completion.referenceKey(), (exchange, request) -> { + var completionResult = Mono.fromCallable( + () -> completion.completionHandler().apply(new McpSyncServerExchange(exchange), request)); + return immediateExecution ? completionResult + : completionResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a tool with its synchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. + * + *

    + * Example tool specification: + * + *

    {@code
    +	 * McpServerFeatures.SyncToolSpecification.builder()
    +	 * 		.tool(Tool.builder()
    +	 * 				.name("calculator")
    +	 * 				.title("Performs mathematical calculations")
    +	 * 				.inputSchema(new JsonSchemaObject()
    +	 * 						.required("expression")
    +	 * 						.property("expression", JsonSchemaType.STRING))
    +	 * 				.build()
    +	 * 		.toolHandler((exchange, req) -> {
    +	 * 			String expr = (String) req.arguments().get("expression");
    +	 * 			return CallToolResult.builder()
    +	 *                   .content(List.of(new McpSchema.TextContent("Result: " + evaluate(expr))))
    +	 *                   .isError(false)
    +	 *                   .build();
    +	 * 		}))
    +	 *      .build();
    +	 * }
    + * + * @param tool The tool definition including name, description, and parameter schema + * @param call (Deprected) The function that implements the tool's logic, receiving + * arguments and returning results. The function's first argument is an + * {@link McpSyncServerExchange} upon which the server can interact with the connected + * @param callHandler The function that implements the tool's logic, receiving a + * {@link McpSyncServerExchange} and a + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} and returning + * results. The function's first argument is an {@link McpSyncServerExchange} upon + * which the server can interact with the client. The second arguments is a map of + * arguments passed to the tool. + */ + public record SyncToolSpecification(McpSchema.Tool tool, + @Deprecated BiFunction, McpSchema.CallToolResult> call, + BiFunction callHandler) { + + @Deprecated + public SyncToolSpecification(McpSchema.Tool tool, + BiFunction, McpSchema.CallToolResult> call) { + this(tool, call, (exchange, toolReq) -> call.apply(exchange, toolReq.arguments())); + } + + /** + * Builder for creating SyncToolSpecification instances. + */ + public static class Builder { + + private McpSchema.Tool tool; + + private BiFunction callHandler; + + /** + * Sets the tool definition. + * @param tool The tool definition including name, description, and parameter + * schema + * @return this builder instance + */ + public Builder tool(McpSchema.Tool tool) { + this.tool = tool; + return this; + } + + /** + * Sets the call tool handler function. + * @param callHandler The function that implements the tool's logic + * @return this builder instance + */ + public Builder callHandler( + BiFunction callHandler) { + this.callHandler = callHandler; + return this; + } + + /** + * Builds the SyncToolSpecification instance. + * @return a new SyncToolSpecification instance + * @throws IllegalArgumentException if required fields are not set + */ + public SyncToolSpecification build() { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "CallTool function must not be null"); + + return new SyncToolSpecification(tool, null, callHandler); + } + + } + + /** + * Creates a new builder instance. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + } + + /** + * Specification of a resource with its synchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
      + *
    • File contents + *
    • Database records + *
    • API responses + *
    • System information + *
    • Application state + *
    + * + *

    + * Example resource specification: + * + *

    {@code
    +	 * new McpServerFeatures.SyncResourceSpecification(
    +	 *     Resource.builder()
    +	 *         .uri("docs")
    +	 *         .name("Documentation files")
    +	 * 		   .title("Documentation files")
    +	 * 		   .mimeType("text/markdown")
    +	 * 		   .description("Markdown documentation files")
    +	 * 		   .build(),
    +	 * 		(exchange, request) -> {
    +	 * 			String content = readFile(request.getPath());
    +	 * 			return new ReadResourceResult(content);
    +	 * 		})
    +	 * }
    + * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpSyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest}. + */ + public record SyncResourceSpecification(McpSchema.Resource resource, + BiFunction readHandler) { + } + + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
      + *
    • Parameterized resource definitions + *
    • Dynamic content generation + *
    • Consistent resource formatting + *
    • Contextual data injection + *
    + * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpSyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record SyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction readHandler) { + } + + /** + * Specification of a prompt template with its synchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
      + *
    • Consistent message formatting + *
    • Parameter substitution + *
    • Context injection + *
    • Response formatting + *
    • Instruction templating + *
    + * + *

    + * Example prompt specification: + * + *

    {@code
    +	 * new McpServerFeatures.SyncPromptSpecification(
    +	 * 		new Prompt("analyze", "Code analysis template"),
    +	 * 		(exchange, request) -> {
    +	 * 			String code = request.getArguments().get("code");
    +	 * 			return new GetPromptResult(
    +	 * 					"Analyze this code:\n\n" + code + "\n\nProvide feedback on:");
    +	 * 		})
    +	 * }
    + * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's first argument is an + * {@link McpSyncServerExchange} upon which the server can interact with the connected + * client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.GetPromptRequest}. + */ + public record SyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction promptHandler) { + } + + /** + * Specification of a completion handler function with synchronous execution support. + * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The synchronous function that processes completion + * requests and returns results. The first argument is an + * {@link McpSyncServerExchange} used to interact with the client. The second argument + * is a {@link io.modelcontextprotocol.spec.McpSchema.CompleteRequest}. + */ + public record SyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction completionHandler) { + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java new file mode 100644 index 000000000..c7a1fd0d7 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java @@ -0,0 +1,851 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncResourceTemplateSpecification; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; + +import static io.modelcontextprotocol.spec.McpError.RESOURCE_NOT_FOUND; + +/** + * A stateless MCP server implementation for use with Streamable HTTP transport types. It + * allows simple horizontal scalability since it does not maintain a session and does not + * require initialization. Each instance of the server can be reached with no prior + * knowledge and can serve the clients with the capabilities it supports. + * + * @author Dariusz Jędrzejczyk + */ +public class McpStatelessAsyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpStatelessAsyncServer.class); + + private final McpStatelessServerTransport mcpTransportProvider; + + private final McpJsonMapper jsonMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final String instructions; + + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resourceTemplates = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + + private List protocolVersions; + + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); + + private final JsonSchemaValidator jsonSchemaValidator; + + McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, McpJsonMapper jsonMapper, + McpStatelessServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + this.mcpTransportProvider = mcpTransport; + this.jsonMapper = jsonMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); + this.resources.putAll(features.resources()); + this.resourceTemplates.putAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.jsonSchemaValidator = jsonSchemaValidator; + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (ctx, params) -> Mono.just(Map.of())); + + requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add completion API handlers if the completion capability is enabled + if (this.serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + } + + this.protocolVersions = new ArrayList<>(mcpTransport.protocolVersions()); + + McpStatelessServerHandler handler = new DefaultMcpStatelessServerHandler(requestHandlers, Map.of()); + mcpTransport.setMcpHandler(handler); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private McpStatelessRequestHandler asyncInitializeRequestHandler() { + return (ctx, req) -> Mono.defer(() -> { + McpSchema.InitializeRequest initializeRequest = this.jsonMapper.convertValue(req, + McpSchema.InitializeRequest.class); + + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, this.instructions)); + }); + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.mcpTransportProvider.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.mcpTransportProvider.close(); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + private static List withStructuredOutputHandling( + JsonSchemaValidator jsonSchemaValidator, List tools) { + + if (Utils.isEmpty(tools)) { + return tools; + } + + return tools.stream().map(tool -> withStructuredOutputHandling(jsonSchemaValidator, tool)).toList(); + } + + private static McpStatelessServerFeatures.AsyncToolSpecification withStructuredOutputHandling( + JsonSchemaValidator jsonSchemaValidator, + McpStatelessServerFeatures.AsyncToolSpecification toolSpecification) { + + if (toolSpecification.callHandler() instanceof StructuredOutputCallToolHandler) { + // If the tool is already wrapped, return it as is + return toolSpecification; + } + + if (toolSpecification.tool().outputSchema() == null) { + // If the tool does not have an output schema, return it as is + return toolSpecification; + } + + return new McpStatelessServerFeatures.AsyncToolSpecification(toolSpecification.tool(), + new StructuredOutputCallToolHandler(jsonSchemaValidator, toolSpecification.tool().outputSchema(), + toolSpecification.callHandler())); + } + + private static class StructuredOutputCallToolHandler + implements BiFunction> { + + private final BiFunction> delegateHandler; + + private final JsonSchemaValidator jsonSchemaValidator; + + private final Map outputSchema; + + public StructuredOutputCallToolHandler(JsonSchemaValidator jsonSchemaValidator, + Map outputSchema, + BiFunction> delegateHandler) { + + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + Assert.notNull(delegateHandler, "Delegate call tool result handler must not be null"); + + this.delegateHandler = delegateHandler; + this.outputSchema = outputSchema; + this.jsonSchemaValidator = jsonSchemaValidator; + } + + @Override + public Mono apply(McpTransportContext transportContext, McpSchema.CallToolRequest request) { + + return this.delegateHandler.apply(transportContext, request).map(result -> { + + if (Boolean.TRUE.equals(result.isError())) { + // If the tool call resulted in an error, skip further validation + return result; + } + + if (outputSchema == null) { + if (result.structuredContent() != null) { + logger.warn( + "Tool call with no outputSchema is not expected to have a result with structured content, but got: {}", + result.structuredContent()); + } + // Pass through. No validation is required if no output schema is + // provided. + return result; + } + + // If an output schema is provided, servers MUST provide structured + // results that conform to this schema. + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema + if (result.structuredContent() == null) { + String content = "Response missing structured content which is expected when calling tool with non-empty outputSchema"; + logger.warn(content); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(content))) + .isError(true) + .build(); + } + + // Validate the result against the output schema + var validation = this.jsonSchemaValidator.validate(outputSchema, result.structuredContent()); + + if (!validation.valid()) { + logger.warn("Tool call result validation failed: {}", validation.errorMessage()); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.errorMessage()))) + .isError(true) + .build(); + } + + if (Utils.isEmpty(result.content())) { + // For backwards compatibility, a tool that returns structured + // content SHOULD also return functionally equivalent unstructured + // content. (For example, serialized JSON can be returned in a + // TextContent block.) + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content + + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput()))) + .isError(result.isError()) + .structuredContent(result.structuredContent()) + .build(); + } + + return result; + }); + } + + } + + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpStatelessServerFeatures.AsyncToolSpecification toolSpecification) { + if (toolSpecification == null) { + return Mono.error(new IllegalArgumentException("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new IllegalArgumentException("Tool must not be null")); + } + if (toolSpecification.callHandler() == null) { + return Mono.error(new IllegalArgumentException("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); + } + + var wrappedToolSpecification = withStructuredOutputHandling(this.jsonSchemaValidator, toolSpecification); + + return Mono.defer(() -> { + // Remove tools with duplicate tool names first + if (this.tools.removeIf(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { + logger.warn("Replace existing Tool with name '{}'", wrappedToolSpecification.tool().name()); + } + + this.tools.add(wrappedToolSpecification); + logger.debug("Added tool handler: {}", wrappedToolSpecification.tool().name()); + + return Mono.empty(); + }); + } + + /** + * List all registered tools. + * @return A Flux stream of all registered tools + */ + public Flux listTools() { + return Flux.fromIterable(this.tools).map(McpStatelessServerFeatures.AsyncToolSpecification::tool); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new IllegalArgumentException("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + if (this.tools.removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName))) { + + logger.debug("Removed tool handler: {}", toolName); + } + else { + logger.warn("Ignore as a Tool with name '{}' not found", toolName); + } + + return Mono.empty(); + }); + } + + private McpStatelessRequestHandler toolsListRequestHandler() { + return (ctx, params) -> { + List tools = this.tools.stream() + .map(McpStatelessServerFeatures.AsyncToolSpecification::tool) + .toList(); + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpStatelessRequestHandler toolsCallRequestHandler() { + return (ctx, params) -> { + McpSchema.CallToolRequest callToolRequest = jsonMapper.convertValue(params, + new TypeRef() { + }); + + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isEmpty()) { + return Mono.error(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Unknown tool: invalid_tool_name") + .data("Tool not found: " + callToolRequest.name()) + .build()); + } + + return toolSpecification.get().callHandler().apply(ctx, callToolRequest); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpStatelessServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { + return Mono.error(new IllegalArgumentException("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + var previous = this.resources.put(resourceSpecification.resource().uri(), resourceSpecification); + if (previous != null) { + logger.warn("Replace existing Resource with URI '{}'", resourceSpecification.resource().uri()); + } + else { + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); + } + return Mono.empty(); + }); + } + + /** + * List all registered resources. + * @return A Flux stream of all registered resources + */ + public Flux listResources() { + return Flux.fromIterable(this.resources.values()) + .map(McpStatelessServerFeatures.AsyncResourceSpecification::resource); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new IllegalArgumentException("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + } + else { + logger.warn("Resource with URI '{}' not found", resourceUri); + } + return Mono.empty(); + }); + } + + /** + * Add a new resource template at runtime. + * @param resourceTemplateSpecification The resource template to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResourceTemplate( + McpStatelessServerFeatures.AsyncResourceTemplateSpecification resourceTemplateSpecification) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow adding resource templates")); + } + + return Mono.defer(() -> { + var previous = this.resourceTemplates.put(resourceTemplateSpecification.resourceTemplate().uriTemplate(), + resourceTemplateSpecification); + if (previous != null) { + logger.warn("Replace existing Resource Template with URI '{}'", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + else { + logger.debug("Added resource template handler: {}", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + return Mono.empty(); + }); + } + + /** + * List all registered resource templates. + * @return A Flux stream of all registered resource templates + */ + public Flux listResourceTemplates() { + return Flux.fromIterable(this.resourceTemplates.values()) + .map(McpStatelessServerFeatures.AsyncResourceTemplateSpecification::resourceTemplate); + } + + /** + * Remove a resource template at runtime. + * @param uriTemplate The URI template of the resource template to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResourceTemplate(String uriTemplate) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow removing resource templates")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncResourceTemplateSpecification removed = this.resourceTemplates + .remove(uriTemplate); + if (removed != null) { + logger.debug("Removed resource template: {}", uriTemplate); + } + else { + logger.warn("Ignore as a Resource Template with URI '{}' not found", uriTemplate); + } + return Mono.empty(); + }); + } + + private McpStatelessRequestHandler resourcesListRequestHandler() { + return (ctx, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpStatelessServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpStatelessRequestHandler resourceTemplateListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resourceTemplates.values() + .stream() + .map(AsyncResourceTemplateSpecification::resourceTemplate) + .toList(); + return Mono.just(new McpSchema.ListResourceTemplatesResult(resourceList, null)); + }; + } + + private McpStatelessRequestHandler resourcesReadRequestHandler() { + return (ctx, params) -> { + McpSchema.ReadResourceRequest resourceRequest = jsonMapper.convertValue(params, new TypeRef<>() { + }); + var resourceUri = resourceRequest.uri(); + + // First try to find a static resource specification + // Static resources have exact URIs + return this.findResourceSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ctx, resourceRequest)) + .orElseGet(() -> { + // If not found, try to find a dynamic resource specification + // Dynamic resources have URI templates + return this.findResourceTemplateSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ctx, resourceRequest)) + .orElseGet(() -> Mono.error(RESOURCE_NOT_FOUND.apply(resourceUri))); + }); + + }; + } + + private Optional findResourceSpecification(String uri) { + var result = this.resources.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resource().uri()).matches(uri)) + .findFirst(); + return result; + } + + private Optional findResourceTemplateSpecification( + String uri) { + return this.resourceTemplates.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resourceTemplate().uriTemplate()).matches(uri)) + .findFirst(); + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification promptSpecification) { + if (promptSpecification == null) { + return Mono.error(new IllegalArgumentException("Prompt specification must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + var previous = this.prompts.put(promptSpecification.prompt().name(), promptSpecification); + if (previous != null) { + logger.warn("Replace existing Prompt with name '{}'", promptSpecification.prompt().name()); + } + else { + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); + } + + return Mono.empty(); + }); + } + + /** + * List all registered prompts. + * @return A Flux stream of all registered prompts + */ + public Flux listPrompts() { + return Flux.fromIterable(this.prompts.values()) + .map(McpStatelessServerFeatures.AsyncPromptSpecification::prompt); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new IllegalArgumentException("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + return Mono.empty(); + } + else { + logger.warn("Ignore as a Prompt with name '{}' not found", promptName); + } + + return Mono.empty(); + }); + } + + private McpStatelessRequestHandler promptsListRequestHandler() { + return (ctx, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpStatelessServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpStatelessRequestHandler promptsGetRequestHandler() { + return (ctx, params) -> { + McpSchema.GetPromptRequest promptRequest = jsonMapper.convertValue(params, + new TypeRef() { + }); + + // Implement prompt retrieval logic here + McpStatelessServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Invalid prompt name") + .data("Prompt not found: " + promptRequest.name()) + .build()); + } + + return specification.promptHandler().apply(ctx, promptRequest); + }; + } + + private static final Mono EMPTY_COMPLETION_RESULT = Mono + .just(new McpSchema.CompleteResult(new CompleteCompletion(List.of(), 0, false))); + + private McpStatelessRequestHandler completionCompleteRequestHandler() { + return (ctx, params) -> { + McpSchema.CompleteRequest request = parseCompletionParams(params); + + if (request.ref() == null) { + return Mono.error( + McpError.builder(ErrorCodes.INVALID_PARAMS).message("Completion ref must not be null").build()); + } + + if (request.ref().type() == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Completion ref type must not be null") + .build()); + } + + String type = request.ref().type(); + + String argumentName = request.argument().name(); + + // Check if valid a Prompt exists for this completion request + if (type.equals(PromptReference.TYPE) + && request.ref() instanceof McpSchema.PromptReference promptReference) { + + McpStatelessServerFeatures.AsyncPromptSpecification promptSpec = this.prompts + .get(promptReference.name()); + if (promptSpec == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Prompt not found: " + promptReference.name()) + .build()); + } + if (!promptSpec.prompt() + .arguments() + .stream() + .filter(arg -> arg.name().equals(argumentName)) + .findFirst() + .isPresent()) { + + logger.warn("Argument not found: {} in prompt: {}", argumentName, promptReference.name()); + + return EMPTY_COMPLETION_RESULT; + } + } + + // Check if valid Resource or ResourceTemplate exists for this completion + // request + if (type.equals(ResourceReference.TYPE) + && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + + var uriTemplateManager = uriTemplateManagerFactory.create(resourceReference.uri()); + + if (!uriTemplateManager.isUriTemplate(resourceReference.uri())) { + // Attempting to autocomplete a fixed resource URI is not an error in + // the spec (but probably should be). + return EMPTY_COMPLETION_RESULT; + } + + McpStatelessServerFeatures.AsyncResourceSpecification resourceSpec = this + .findResourceSpecification(resourceReference.uri()) + .orElse(null); + + if (resourceSpec != null) { + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource: " + resourceReference.uri()) + .build()); + } + } + else { + var templateSpec = this.findResourceTemplateSpecification(resourceReference.uri()).orElse(null); + if (templateSpec != null) { + + if (!uriTemplateManagerFactory.create(templateSpec.resourceTemplate().uriTemplate()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource template: " + + resourceReference.uri()) + .build()); + } + } + else { + return Mono.error(RESOURCE_NOT_FOUND.apply(resourceReference.uri())); + } + } + } + + McpStatelessServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); + + if (specification == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("AsyncCompletionSpecification not found: " + request.ref()) + .build()); + } + + return specification.completionHandler().apply(ctx, request); + }; + } + + /** + * Parses the raw JSON-RPC request parameters into a {@link McpSchema.CompleteRequest} + * object. + *

    + * This method manually extracts the `ref` and `argument` fields from the input map, + * determines the correct reference type (either prompt or resource), and constructs a + * fully-typed {@code CompleteRequest} instance. + * @param object the raw request parameters, expected to be a Map containing "ref" and + * "argument" entries. + * @return a {@link McpSchema.CompleteRequest} representing the structured completion + * request. + * @throws IllegalArgumentException if the "ref" type is not recognized. + */ + @SuppressWarnings("unchecked") + private McpSchema.CompleteRequest parseCompletionParams(Object object) { + Map params = (Map) object; + Map refMap = (Map) params.get("ref"); + Map argMap = (Map) params.get("argument"); + + String refType = (String) refMap.get("type"); + + McpSchema.CompleteReference ref = switch (refType) { + case PromptReference.TYPE -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), + refMap.get("title") != null ? (String) refMap.get("title") : null); + case ResourceReference.TYPE -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + default -> throw new IllegalArgumentException("Invalid ref type: " + refType); + }; + + String argName = (String) argMap.get("name"); + String argValue = (String) argMap.get("value"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName, + argValue); + + return new McpSchema.CompleteRequest(ref, argument); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java new file mode 100644 index 000000000..a2fabb283 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import reactor.core.publisher.Mono; + +/** + * Handler for MCP notifications in a stateless server. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStatelessNotificationHandler { + + /** + * Handle to notification and complete once done. + * @param transportContext {@link McpTransportContext} associated with the transport + * @param params the payload of the MCP notification + * @return Mono which completes once the processing is done + */ + Mono handle(McpTransportContext transportContext, Object params); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java new file mode 100644 index 000000000..37cd3c096 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import reactor.core.publisher.Mono; + +/** + * Handler for MCP requests in a stateless server. + * + * @param type of the MCP response + * @author Dariusz Jędrzejczyk + */ +public interface McpStatelessRequestHandler { + + /** + * Handle the request and complete with a result. + * @param transportContext {@link McpTransportContext} associated with the transport + * @param params the payload of the MCP request + * @return Mono which completes with the response object + */ + Mono handle(McpTransportContext transportContext, Object params); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java new file mode 100644 index 000000000..a15681ba5 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java @@ -0,0 +1,555 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +/** + * MCP stateless server features specification that a particular server can choose to + * support. + * + * @author Dariusz Jędrzejczyk + * @author Christian Tzolov + */ +public class McpStatelessServerFeatures { + + /** + * Asynchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The map of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The map of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental + null, // currently statless server doesn't support set logging + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : List.of(); + this.resources = (resources != null) ? resources : Map.of(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); + this.prompts = (prompts != null) ? prompts : Map.of(); + this.completions = (completions != null) ? completions : Map.of(); + this.instructions = instructions; + } + + /** + * Convert a synchronous specification into an asynchronous one and provide + * blocking code offloading to prevent accidental blocking of the non-blocking + * transport. + * @param syncSpec a potentially blocking, synchronous specification. + * @param immediateExecution when true, do not offload. Do NOT set to true when + * using a non-blocking transport. + * @return a specification which is protected from blocking calls specified by the + * user. + */ + static Async fromSync(Sync syncSpec, boolean immediateExecution) { + List tools = new ArrayList<>(); + for (var tool : syncSpec.tools()) { + tools.add(AsyncToolSpecification.fromSync(tool, immediateExecution)); + } + + Map resources = new HashMap<>(); + syncSpec.resources().forEach((key, resource) -> { + resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); + }); + + Map resourceTemplates = new HashMap<>(); + syncSpec.resourceTemplates().forEach((key, resource) -> { + resourceTemplates.put(key, AsyncResourceTemplateSpecification.fromSync(resource, immediateExecution)); + }); + + Map prompts = new HashMap<>(); + syncSpec.prompts().forEach((key, prompt) -> { + prompts.put(key, AsyncPromptSpecification.fromSync(prompt, immediateExecution)); + }); + + Map completions = new HashMap<>(); + syncSpec.completions().forEach((key, completion) -> { + completions.put(key, AsyncCompletionSpecification.fromSync(completion, immediateExecution)); + }); + + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, resourceTemplates, + prompts, completions, syncSpec.instructions()); + } + } + + /** + * Synchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The map of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The map of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental + new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable + // logging + // by + // default + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : new ArrayList<>(); + this.resources = (resources != null) ? resources : new HashMap<>(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); + this.prompts = (prompts != null) ? prompts : new HashMap<>(); + this.completions = (completions != null) ? completions : new HashMap<>(); + this.instructions = instructions; + } + + } + + /** + * Specification of a tool with its asynchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. Each tool + * represents a specific capability. + * + * @param tool The tool definition including name, description, and parameter schema + * @param callHandler The function that implements the tool's logic, receiving a + * {@link CallToolRequest} and returning the result. + */ + public record AsyncToolSpecification(McpSchema.Tool tool, + BiFunction> callHandler) { + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec) { + return fromSync(syncToolSpec, false); + } + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec, boolean immediate) { + + // FIXME: This is temporary, proper validation should be implemented + if (syncToolSpec == null) { + return null; + } + + BiFunction> callHandler = (ctx, + req) -> { + var toolResult = Mono.fromCallable(() -> syncToolSpec.callHandler().apply(ctx, req)); + return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); + }; + + return new AsyncToolSpecification(syncToolSpec.tool(), callHandler); + } + + /** + * Builder for creating AsyncToolSpecification instances. + */ + public static class Builder { + + private McpSchema.Tool tool; + + private BiFunction> callHandler; + + /** + * Sets the tool definition. + * @param tool The tool definition including name, description, and parameter + * schema + * @return this builder instance + */ + public Builder tool(McpSchema.Tool tool) { + this.tool = tool; + return this; + } + + /** + * Sets the call tool handler function. + * @param callHandler The function that implements the tool's logic + * @return this builder instance + */ + public Builder callHandler( + BiFunction> callHandler) { + this.callHandler = callHandler; + return this; + } + + /** + * Builds the AsyncToolSpecification instance. + * @return a new AsyncToolSpecification instance + * @throws IllegalArgumentException if required fields are not set + */ + public AsyncToolSpecification build() { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Call handler function must not be null"); + + return new AsyncToolSpecification(tool, callHandler); + } + + } + + /** + * Creates a new builder instance. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + } + + /** + * Specification of a resource with its asynchronous handler function. Resources + * provide context to AI models by exposing data such as: + *

      + *
    • File contents + *
    • Database records + *
    • API responses + *
    • System information + *
    • Application state + *
    + * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * argument is a {@link McpSchema.ReadResourceRequest}. + */ + public record AsyncResourceSpecification(McpSchema.Resource resource, + BiFunction> readHandler) { + + static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceSpecification(resource.resource(), (ctx, req) -> { + var resourceResult = Mono.fromCallable(() -> resource.readHandler().apply(ctx, req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
      + *
    • Parameterized resource definitions + *
    • Dynamic content generation + *
    • Consistent resource formatting + *
    • Contextual data injection + *
    + * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpTransportContext} upon which the server can interact + * with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record AsyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction> readHandler) { + + static AsyncResourceTemplateSpecification fromSync(SyncResourceTemplateSpecification resource, + boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceTemplateSpecification(resource.resourceTemplate(), (ctx, req) -> { + var resourceResult = Mono.fromCallable(() -> resource.readHandler().apply(ctx, req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a prompt template with its asynchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
      + *
    • Consistent message formatting + *
    • Parameter substitution + *
    • Context injection + *
    • Response formatting + *
    • Instruction templating + *
    + * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's argument is a + * {@link McpSchema.GetPromptRequest}. + */ + public record AsyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction> promptHandler) { + + static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (prompt == null) { + return null; + } + return new AsyncPromptSpecification(prompt.prompt(), (ctx, req) -> { + var promptResult = Mono.fromCallable(() -> prompt.promptHandler().apply(ctx, req)); + return immediateExecution ? promptResult : promptResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a completion handler function with asynchronous execution support. + * Completions generate AI model outputs based on prompt or resource references and + * user-provided arguments. This abstraction enables: + *
      + *
    • Customizable response generation logic + *
    • Parameter-driven template expansion + *
    • Dynamic interaction with connected clients + *
    + * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The asynchronous function that processes completion + * requests and returns results. The function's argument is a + * {@link McpSchema.CompleteRequest}. + */ + public record AsyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction> completionHandler) { + + /** + * Converts a synchronous {@link SyncCompletionSpecification} into an + * {@link AsyncCompletionSpecification} by wrapping the handler in a bounded + * elastic scheduler for safe non-blocking execution. + * @param completion the synchronous completion specification + * @return an asynchronous wrapper of the provided sync specification, or + * {@code null} if input is null + */ + static AsyncCompletionSpecification fromSync(SyncCompletionSpecification completion, + boolean immediateExecution) { + if (completion == null) { + return null; + } + return new AsyncCompletionSpecification(completion.referenceKey(), (ctx, req) -> { + var completionResult = Mono.fromCallable(() -> completion.completionHandler().apply(ctx, req)); + return immediateExecution ? completionResult + : completionResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a tool with its synchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. + * + * @param tool The tool definition including name, description, and parameter schema + * @param callHandler The function that implements the tool's logic, receiving a + * {@link CallToolRequest} and returning results. + */ + public record SyncToolSpecification(McpSchema.Tool tool, + BiFunction callHandler) { + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating SyncToolSpecification instances. + */ + public static class Builder { + + private McpSchema.Tool tool; + + private BiFunction callHandler; + + /** + * Sets the tool definition. + * @param tool The tool definition including name, description, and parameter + * schema + * @return this builder instance + */ + public Builder tool(McpSchema.Tool tool) { + this.tool = tool; + return this; + } + + /** + * Sets the call tool handler function. + * @param callHandler The function that implements the tool's logic + * @return this builder instance + */ + public Builder callHandler( + BiFunction callHandler) { + this.callHandler = callHandler; + return this; + } + + /** + * Builds the SyncToolSpecification instance. + * @return a new SyncToolSpecification instance + * @throws IllegalArgumentException if required fields are not set + */ + public SyncToolSpecification build() { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "CallTool function must not be null"); + + return new SyncToolSpecification(tool, callHandler); + } + + } + } + + /** + * Specification of a resource with its synchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
      + *
    • File contents + *
    • Database records + *
    • API responses + *
    • System information + *
    • Application state + *
    + * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * argument is a {@link McpSchema.ReadResourceRequest}. + */ + public record SyncResourceSpecification(McpSchema.Resource resource, + BiFunction readHandler) { + } + + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
      + *
    • Parameterized resource definitions + *
    • Dynamic content generation + *
    • Consistent resource formatting + *
    • Contextual data injection + *
    + * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpTransportContext} upon which the server can interact + * with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record SyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction readHandler) { + } + + /** + * Specification of a prompt template with its synchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
      + *
    • Consistent message formatting + *
    • Parameter substitution + *
    • Context injection + *
    • Response formatting + *
    • Instruction templating + *
    + * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's argument is a + * {@link McpSchema.GetPromptRequest}. + */ + public record SyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction promptHandler) { + } + + /** + * Specification of a completion handler function with synchronous execution support. + * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The synchronous function that processes completion + * requests and returns results. The argument is a {@link McpSchema.CompleteRequest}. + */ + public record SyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction completionHandler) { + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java new file mode 100644 index 000000000..cbae58bfd --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java @@ -0,0 +1,38 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; + +/** + * Handler for MCP requests and notifications in a Stateless Streamable HTTP Server + * context. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStatelessServerHandler { + + /** + * Handle the request using user-provided feature implementations. + * @param transportContext {@link McpTransportContext} carrying transport layer + * metadata + * @param request the request JSON object + * @return Mono containing the JSON response + */ + Mono handleRequest(McpTransportContext transportContext, + McpSchema.JSONRPCRequest request); + + /** + * Handle the notification. + * @param transportContext {@link McpTransportContext} carrying transport layer + * metadata + * @param notification the notification JSON object + * @return Mono that completes once handling is finished + */ + Mono handleNotification(McpTransportContext transportContext, McpSchema.JSONRPCNotification notification); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java new file mode 100644 index 000000000..6849eb8ed --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java @@ -0,0 +1,184 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import java.util.List; + +/** + * A stateless MCP server implementation for use with Streamable HTTP transport types. It + * allows simple horizontal scalability since it does not maintain a session and does not + * require initialization. Each instance of the server can be reached with no prior + * knowledge and can serve the clients with the capabilities it supports. + * + * @author Dariusz Jędrzejczyk + */ +public class McpStatelessSyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpStatelessSyncServer.class); + + private final McpStatelessAsyncServer asyncServer; + + private final boolean immediateExecution; + + McpStatelessSyncServer(McpStatelessAsyncServer asyncServer, boolean immediateExecution) { + this.asyncServer = asyncServer; + this.immediateExecution = immediateExecution; + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.asyncServer.getServerCapabilities(); + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.asyncServer.getServerInfo(); + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.asyncServer.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.asyncServer.close(); + } + + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + */ + public void addTool(McpStatelessServerFeatures.SyncToolSpecification toolSpecification) { + this.asyncServer + .addTool(McpStatelessServerFeatures.AsyncToolSpecification.fromSync(toolSpecification, + this.immediateExecution)) + .block(); + } + + /** + * List all registered tools. + * @return A list of all registered tools + */ + public List listTools() { + return this.asyncServer.listTools().collectList().block(); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + */ + public void removeTool(String toolName) { + this.asyncServer.removeTool(toolName).block(); + } + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + */ + public void addResource(McpStatelessServerFeatures.SyncResourceSpecification resourceSpecification) { + this.asyncServer + .addResource(McpStatelessServerFeatures.AsyncResourceSpecification.fromSync(resourceSpecification, + this.immediateExecution)) + .block(); + } + + /** + * List all registered resources. + * @return A list of all registered resources + */ + public List listResources() { + return this.asyncServer.listResources().collectList().block(); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + */ + public void removeResource(String resourceUri) { + this.asyncServer.removeResource(resourceUri).block(); + } + + /** + * Add a new resource template. + * @param resourceTemplateSpecification The resource template specification to add + */ + public void addResourceTemplate( + McpStatelessServerFeatures.SyncResourceTemplateSpecification resourceTemplateSpecification) { + this.asyncServer + .addResourceTemplate(McpStatelessServerFeatures.AsyncResourceTemplateSpecification + .fromSync(resourceTemplateSpecification, this.immediateExecution)) + .block(); + } + + /** + * List all registered resource templates. + * @return A list of all registered resource templates + */ + public List listResourceTemplates() { + return this.asyncServer.listResourceTemplates().collectList().block(); + } + + /** + * Remove a resource template. + * @param uriTemplate The URI template of the resource template to remove + */ + public void removeResourceTemplate(String uriTemplate) { + this.asyncServer.removeResourceTemplate(uriTemplate).block(); + } + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + */ + public void addPrompt(McpStatelessServerFeatures.SyncPromptSpecification promptSpecification) { + this.asyncServer + .addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification, + this.immediateExecution)) + .block(); + } + + /** + * List all registered prompts. + * @return A list of all registered prompts + */ + public List listPrompts() { + return this.asyncServer.listPrompts().collectList().block(); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + */ + public void removePrompt(String promptName) { + this.asyncServer.removePrompt(promptName).block(); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.asyncServer.setProtocolVersions(protocolVersions); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java similarity index 56% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index 1de0139ba..10f0e5a31 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -4,9 +4,9 @@ package io.modelcontextprotocol.server; -import io.modelcontextprotocol.spec.McpError; +import java.util.List; + import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.util.Assert; @@ -56,38 +56,45 @@ public class McpSyncServer { */ private final McpAsyncServer asyncServer; + private final boolean immediateExecution; + /** * Creates a new synchronous server that wraps the provided async server. * @param asyncServer The async server to wrap */ public McpSyncServer(McpAsyncServer asyncServer) { - Assert.notNull(asyncServer, "Async server must not be null"); - this.asyncServer = asyncServer; + this(asyncServer, false); } /** - * Retrieves the list of all roots provided by the client. - * @return The list of roots + * Creates a new synchronous server that wraps the provided async server. + * @param asyncServer The async server to wrap + * @param immediateExecution Tools, prompts, and resources handlers execute work + * without blocking code offloading. Do NOT set to true if the {@code asyncServer}'s + * transport is non-blocking. */ - public McpSchema.ListRootsResult listRoots() { - return this.listRoots(null); + public McpSyncServer(McpAsyncServer asyncServer, boolean immediateExecution) { + Assert.notNull(asyncServer, "Async server must not be null"); + this.asyncServer = asyncServer; + this.immediateExecution = immediateExecution; } /** - * Retrieves a paginated list of roots provided by the server. - * @param cursor Optional pagination cursor from a previous list request - * @return The list of roots + * Add a new tool handler. + * @param toolHandler The tool handler to add */ - public McpSchema.ListRootsResult listRoots(String cursor) { - return this.asyncServer.listRoots(cursor).block(); + public void addTool(McpServerFeatures.SyncToolSpecification toolHandler) { + this.asyncServer + .addTool(McpServerFeatures.AsyncToolSpecification.fromSync(toolHandler, this.immediateExecution)) + .block(); } /** - * Add a new tool handler. - * @param toolHandler The tool handler to add + * List all registered tools. + * @return A list of all registered tools */ - public void addTool(McpServerFeatures.SyncToolRegistration toolHandler) { - this.asyncServer.addTool(McpServerFeatures.AsyncToolRegistration.fromSync(toolHandler)).block(); + public List listTools() { + return this.asyncServer.listTools().collectList().block(); } /** @@ -100,10 +107,21 @@ public void removeTool(String toolName) { /** * Add a new resource handler. - * @param resourceHandler The resource handler to add + * @param resourceSpecification The resource specification to add */ - public void addResource(McpServerFeatures.SyncResourceRegistration resourceHandler) { - this.asyncServer.addResource(McpServerFeatures.AsyncResourceRegistration.fromSync(resourceHandler)).block(); + public void addResource(McpServerFeatures.SyncResourceSpecification resourceSpecification) { + this.asyncServer + .addResource(McpServerFeatures.AsyncResourceSpecification.fromSync(resourceSpecification, + this.immediateExecution)) + .block(); + } + + /** + * List all registered resources. + * @return A list of all registered resources + */ + public List listResources() { + return this.asyncServer.listResources().collectList().block(); } /** @@ -114,12 +132,50 @@ public void removeResource(String resourceUri) { this.asyncServer.removeResource(resourceUri).block(); } + /** + * Add a new resource template. + * @param resourceTemplateSpecification The resource template specification to add + */ + public void addResourceTemplate(McpServerFeatures.SyncResourceTemplateSpecification resourceTemplateSpecification) { + this.asyncServer + .addResourceTemplate(McpServerFeatures.AsyncResourceTemplateSpecification + .fromSync(resourceTemplateSpecification, this.immediateExecution)) + .block(); + } + + /** + * List all registered resource templates. + * @return A list of all registered resource templates + */ + public List listResourceTemplates() { + return this.asyncServer.listResourceTemplates().collectList().block(); + } + + /** + * Remove a resource template. + * @param uriTemplate The URI template of the resource template to remove + */ + public void removeResourceTemplate(String uriTemplate) { + this.asyncServer.removeResourceTemplate(uriTemplate).block(); + } + /** * Add a new prompt handler. - * @param promptRegistration The prompt registration to add + * @param promptSpecification The prompt specification to add */ - public void addPrompt(McpServerFeatures.SyncPromptRegistration promptRegistration) { - this.asyncServer.addPrompt(McpServerFeatures.AsyncPromptRegistration.fromSync(promptRegistration)).block(); + public void addPrompt(McpServerFeatures.SyncPromptSpecification promptSpecification) { + this.asyncServer + .addPrompt( + McpServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification, this.immediateExecution)) + .block(); + } + + /** + * List all registered prompts. + * @return A list of all registered prompts + */ + public List listPrompts() { + return this.asyncServer.listPrompts().collectList().block(); } /** @@ -154,26 +210,17 @@ public McpSchema.Implementation getServerInfo() { } /** - * Get the client capabilities that define the supported features and functionality. - * @return The client capabilities - */ - public ClientCapabilities getClientCapabilities() { - return this.asyncServer.getClientCapabilities(); - } - - /** - * Get the client implementation information. - * @return The client implementation details + * Notify clients that the list of available resources has changed. */ - public McpSchema.Implementation getClientInfo() { - return this.asyncServer.getClientInfo(); + public void notifyResourcesListChanged() { + this.asyncServer.notifyResourcesListChanged().block(); } /** - * Notify clients that the list of available resources has changed. + * Notify clients that the resources have updated. */ - public void notifyResourcesListChanged() { - this.asyncServer.notifyResourcesListChanged().block(); + public void notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification) { + this.asyncServer.notifyResourcesUpdated(resourcesUpdatedNotification).block(); } /** @@ -184,9 +231,16 @@ public void notifyPromptsListChanged() { } /** - * Send a logging message notification to all clients. - * @param loggingMessageNotification The logging message notification to send + * This implementation would, incorrectly, broadcast the logging message to all + * connected clients, using a single minLoggingLevel for all of them. Similar to the + * sampling and roots, the logging level should be set per client session and use the + * ServerExchange to send the logging message to the right client. + * @param loggingMessageNotification The logging message to send + * @deprecated Use + * {@link McpSyncServerExchange#loggingNotification(LoggingMessageNotification)} + * instead. */ + @Deprecated public void loggingNotification(LoggingMessageNotification loggingMessageNotification) { this.asyncServer.loggingNotification(loggingMessageNotification).block(); } @@ -213,33 +267,4 @@ public McpAsyncServer getAsyncServer() { return this.asyncServer; } - /** - * Create a new message using the sampling capabilities of the client. The Model - * Context Protocol (MCP) provides a standardized way for servers to request LLM - * sampling ("completions" or "generations") from language models via clients. - * - *

    - * This flow allows clients to maintain control over model access, selection, and - * permissions while enabling servers to leverage AI capabilities—with no server API - * keys necessary. Servers can request text or image-based interactions and optionally - * include context from MCP servers in their prompts. - * - *

    - * Unlike its async counterpart, this method blocks until the message creation is - * complete, making it easier to use in synchronous code paths. - * @param createMessageRequest The request to create a new message - * @return The result of the message creation - * @throws McpError if the client has not been initialized or does not support - * sampling capabilities - * @throws McpError if the client does not support the createMessage method - * @see McpSchema.CreateMessageRequest - * @see McpSchema.CreateMessageResult - * @see Sampling - * Specification - */ - public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageRequest createMessageRequest) { - return this.asyncServer.createMessage(createMessageRequest).block(); - } - } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java new file mode 100644 index 000000000..0b9115b79 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -0,0 +1,146 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; + +/** + * Represents a synchronous exchange with a Model Context Protocol (MCP) client. The + * exchange provides methods to interact with the client and query its capabilities. + * + * @author Dariusz Jędrzejczyk + * @author Christian Tzolov + */ +public class McpSyncServerExchange { + + private final McpAsyncServerExchange exchange; + + /** + * Create a new synchronous exchange with the client using the provided asynchronous + * implementation as a delegate. + * @param exchange The asynchronous exchange to delegate to. + */ + public McpSyncServerExchange(McpAsyncServerExchange exchange) { + this.exchange = exchange; + } + + /** + * Provides the Session ID + * @return session ID + */ + public String sessionId() { + return this.exchange.sessionId(); + } + + /** + * Get the client capabilities that define the supported features and functionality. + * @return The client capabilities + */ + public McpSchema.ClientCapabilities getClientCapabilities() { + return this.exchange.getClientCapabilities(); + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + public McpSchema.Implementation getClientInfo() { + return this.exchange.getClientInfo(); + } + + /** + * Provides the {@link McpTransportContext} associated with the transport layer. For + * HTTP transports it can contain the metadata associated with the HTTP request that + * triggered the processing. + * @return the transport context object + */ + public McpTransportContext transportContext() { + return this.exchange.transportContext(); + } + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. This + * flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server API + * keys necessary. Servers can request text or image-based interactions and optionally + * include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A result containing the details of the sampling response + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + */ + public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + return this.exchange.createMessage(createMessageRequest).block(); + } + + /** + * Creates a new elicitation. MCP provides a standardized way for servers to request + * additional information from users through the client during interactions. This flow + * allows clients to maintain control over user interactions and data sharing while + * enabling servers to gather necessary information dynamically. Servers can request + * structured data from users with optional JSON schemas to validate responses. + * @param elicitRequest The request to create a new elicitation + * @return A result containing the elicitation response. + * @see McpSchema.ElicitRequest + * @see McpSchema.ElicitResult + * @see Elicitation + * Specification + */ + public McpSchema.ElicitResult createElicitation(McpSchema.ElicitRequest elicitRequest) { + return this.exchange.createElicitation(elicitRequest).block(); + } + + /** + * Retrieves the list of all roots provided by the client. + * @return The list of roots result. + */ + public McpSchema.ListRootsResult listRoots() { + return this.exchange.listRoots().block(); + } + + /** + * Retrieves a paginated list of roots provided by the client. + * @param cursor Optional pagination cursor from a previous list request + * @return The list of roots result + */ + public McpSchema.ListRootsResult listRoots(String cursor) { + return this.exchange.listRoots(cursor).block(); + } + + /** + * Send a logging message notification to the client. Messages below the current + * minimum logging level will be filtered out. + * @param loggingMessageNotification The logging message to send + */ + public void loggingNotification(LoggingMessageNotification loggingMessageNotification) { + this.exchange.loggingNotification(loggingMessageNotification).block(); + } + + /** + * Sends a notification to the client that the current progress status has changed for + * long-running operations. + * @param progressNotification The progress notification to send + */ + public void progressNotification(McpSchema.ProgressNotification progressNotification) { + this.exchange.progressNotification(progressNotification).block(); + } + + /** + * Sends a synchronous ping request to the client. + * @return + */ + public Object ping() { + return this.exchange.ping().block(); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java new file mode 100644 index 000000000..ea9f05a4f --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java @@ -0,0 +1,27 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; + +/** + * The contract for extracting metadata from a generic transport request of type + * {@link T}. + * + * @param transport-specific representation of the request which allows extracting + * metadata for use in the MCP features implementations. + * @author Dariusz Jędrzejczyk + */ +public interface McpTransportContextExtractor { + + /** + * Extract transport-specific metadata from the request into an McpTransportContext. + * @param request the generic representation for the request in the context of a + * specific transport implementation + * @return the context containing the metadata + */ + McpTransportContext extract(T request); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java new file mode 100644 index 000000000..96cebb74a --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -0,0 +1,641 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * A Servlet-based implementation of the MCP HTTP with Server-Sent Events (SSE) transport + * specification. This implementation provides similar functionality to + * WebFluxSseServerTransportProvider but uses the traditional Servlet API instead of + * WebFlux. + * + *

    + * The transport handles two types of endpoints: + *

      + *
    • SSE endpoint (/sse) - Establishes a long-lived connection for server-to-client + * events
    • + *
    • Message endpoint (configurable) - Handles client-to-server message requests
    • + *
    + * + *

    + * Features: + *

      + *
    • Asynchronous message handling using Servlet 6.0 async support
    • + *
    • Session management for multiple client connections
    • + *
    • Graceful shutdown support
    • + *
    • Error handling and response formatting
    • + *
    + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @see McpServerTransportProvider + * @see HttpServlet + */ + +@WebServlet(asyncSupported = true) +public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider { + + /** + * Logger for this class + */ + private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportProvider.class); + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + + /** + * Default endpoint path for SSE connections + */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + /** + * Event type for regular messages + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for endpoint information + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + public static final String SESSION_ID = "sessionId"; + + public static final String DEFAULT_BASE_URL = ""; + + /** + * JSON mapper for serialization/deserialization + */ + private final McpJsonMapper jsonMapper; + + /** + * Base URL for the server transport + */ + private final String baseUrl; + + /** + * The endpoint path for handling client messages + */ + private final String messageEndpoint; + + /** + * The endpoint path for handling SSE connections + */ + private final String sseEndpoint; + + /** + * Map of active client sessions, keyed by session ID + */ + private final Map sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + /** + * Flag indicating if the transport is in the process of shutting down + */ + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + /** + * Session factory for creating new sessions + */ + private McpServerSession.Factory sessionFactory; + + /** + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. + */ + private KeepAliveScheduler keepAliveScheduler; + + /** + * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE + * endpoint. + * @param jsonMapper The JSON object mapper to use for message + * serialization/deserialization + * @param baseUrl The base URL for the server transport + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @param keepAliveInterval The interval for keep-alive pings, or null to disable + * keep-alive functionality + * @param contextExtractor The extractor for transport context from the request. + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. + */ + private HttpServletSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, + McpTransportContextExtractor contextExtractor) { + + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + Assert.notNull(messageEndpoint, "messageEndpoint must not be null"); + Assert.notNull(sseEndpoint, "sseEndpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); + + this.jsonMapper = jsonMapper; + this.baseUrl = baseUrl; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + this.contextExtractor = contextExtractor; + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing.get()) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + + /** + * Sets the session factory for creating new sessions. + * @param sessionFactory The session factory to use + */ + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + /** + * Handles GET requests to establish SSE connections. + *

    + * This method sets up a new SSE connection when a client connects to the SSE + * endpoint. It configures the response headers for SSE, creates a new session, and + * sends the initial endpoint information to the client. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(sseEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (isClosing.get()) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + response.setContentType("text/event-stream"); + response.setCharacterEncoding(UTF_8); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + String sessionId = UUID.randomUUID().toString(); + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + PrintWriter writer = response.getWriter(); + + // Create a new session transport + HttpServletMcpSessionTransport sessionTransport = new HttpServletMcpSessionTransport(sessionId, asyncContext, + writer); + + // Create a new session using the session factory + McpServerSession session = sessionFactory.create(sessionTransport); + this.sessions.put(sessionId, session); + + // Send initial endpoint event + this.sendEvent(writer, ENDPOINT_EVENT_TYPE, buildEndpointUrl(sessionId)); + } + + /** + * Constructs the full message endpoint URL by combining the base URL, message path, + * and the required session_id query parameter. + * @param sessionId the unique session identifier + * @return the fully qualified endpoint URL as a string + */ + private String buildEndpointUrl(String sessionId) { + // for WebMVC compatibility + if (this.baseUrl.endsWith("/")) { + return this.baseUrl.substring(0, this.baseUrl.length() - 1) + this.messageEndpoint + "?sessionId=" + + sessionId; + } + return this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId; + } + + /** + * Handles POST requests for client messages. + *

    + * This method processes incoming messages from clients, routes them through the + * session handler, and sends back the appropriate response. It handles error cases + * and formats error responses according to the MCP specification. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + if (isClosing.get()) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(messageEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + // Get the session ID from the request parameter + String sessionId = request.getParameter("sessionId"); + if (sessionId == null) { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + String jsonError = jsonMapper.writeValueAsString(new McpError("Session ID missing in message endpoint")); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + // Get the session from the sessions map + McpServerSession session = sessions.get(sessionId); + if (session == null) { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_NOT_FOUND); + String jsonError = jsonMapper.writeValueAsString(new McpError("Session not found: " + sessionId)); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + final McpTransportContext transportContext = this.contextExtractor.extract(request); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body.toString()); + + // Process the message through the session's handle method + // Block for Servlet compatibility + session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); + + response.setStatus(HttpServletResponse.SC_OK); + } + catch (Exception e) { + logger.error("Error processing message: {}", e.getMessage()); + try { + McpError mcpError = new McpError(e.getMessage()); + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + String jsonError = jsonMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } + catch (IOException ex) { + logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message"); + } + } + } + + /** + * Initiates a graceful shutdown of the transport. + *

    + * This method marks the transport as closing and closes all active client sessions. + * New connection attempts will be rejected during shutdown. + * @return A Mono that completes when all sessions have been closed + */ + @Override + public Mono closeGracefully() { + isClosing.set(true); + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then().doOnSuccess(v -> { + sessions.clear(); + logger.debug("Graceful shutdown completed"); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + /** + * Sends an SSE event to a client. + * @param writer The writer to send the event through + * @param eventType The type of event (message or endpoint) + * @param data The event data + * @throws IOException If an error occurs while writing the event + */ + private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException { + writer.write("event: " + eventType + "\n"); + writer.write("data: " + data + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + + /** + * Cleans up resources when the servlet is being destroyed. + *

    + * This method ensures a graceful shutdown by closing all client connections before + * calling the parent's destroy method. + */ + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + /** + * Implementation of McpServerTransport for HttpServlet SSE sessions. This class + * handles the transport-level communication for a specific client session. + */ + private class HttpServletMcpSessionTransport implements McpServerTransport { + + private final String sessionId; + + private final AsyncContext asyncContext; + + private final PrintWriter writer; + + /** + * Creates a new session transport with the specified ID and SSE writer. + * @param sessionId The unique identifier for this session + * @param asyncContext The async context for the session + * @param writer The writer for sending server events to the client + */ + HttpServletMcpSessionTransport(String sessionId, AsyncContext asyncContext, PrintWriter writer) { + this.sessionId = sessionId; + this.asyncContext = asyncContext; + this.writer = writer; + logger.debug("Session transport {} initialized with SSE writer", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + String jsonText = jsonMapper.writeValueAsString(message); + sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText); + logger.debug("Message sent to session {}", sessionId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); + sessions.remove(sessionId); + asyncContext.complete(); + } + }); + } + + /** + * Converts data from one type to another using the configured JsonMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @param The target type + * @return The converted object of type T + */ + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + logger.debug("Closing session transport: {}", sessionId); + try { + sessions.remove(sessionId); + asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + } + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + try { + sessions.remove(sessionId); + asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + } + } + + } + + /** + * Creates a new Builder instance for configuring and creating instances of + * HttpServletSseServerTransportProvider. + * @return A new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of HttpServletSseServerTransportProvider. + *

    + * This builder provides a fluent API for configuring and creating instances of + * HttpServletSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String baseUrl = DEFAULT_BASE_URL; + + private String messageEndpoint; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private Duration keepAliveInterval; + + /** + * Sets the JsonMapper implementation to use for serialization/deserialization. If + * not specified, a JacksonJsonMapper will be created from the configured + * ObjectMapper. + * @param jsonMapper The JsonMapper to use + * @return This builder instance for method chaining + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the base URL for the server transport. + * @param baseUrl The base URL to use + * @return This builder instance for method chaining + */ + public Builder baseUrl(String baseUrl) { + Assert.notNull(baseUrl, "Base URL must not be null"); + this.baseUrl = baseUrl; + return this; + } + + /** + * Sets the endpoint path where clients will send their messages. + * @param messageEndpoint The message endpoint path + * @return This builder instance for method chaining + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); + this.messageEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the endpoint path where clients will establish SSE connections. + *

    + * If not specified, the default value of {@link #DEFAULT_SSE_ENDPOINT} will be + * used. + * @param sseEndpoint The SSE endpoint path + * @return This builder instance for method chaining + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the context extractor for extracting transport context from the request. + * @param contextExtractor The context extractor to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public HttpServletSseServerTransportProvider.Builder contextExtractor( + McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets the interval for keep-alive pings. + *

    + * If not specified, keep-alive pings will be disabled. + * @param keepAliveInterval The interval duration for keep-alive pings + * @return This builder instance for method chaining + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of HttpServletSseServerTransportProvider with the + * configured settings. + * @return A new HttpServletSseServerTransportProvider instance + * @throws IllegalStateException if jsonMapper or messageEndpoint is not set + */ + public HttpServletSseServerTransportProvider build() { + if (messageEndpoint == null) { + throw new IllegalStateException("MessageEndpoint must be set"); + } + return new HttpServletSseServerTransportProvider( + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, baseUrl, messageEndpoint, sseEndpoint, + keepAliveInterval, contextExtractor); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java new file mode 100644 index 000000000..40767f416 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -0,0 +1,305 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.json.McpJsonMapper; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.util.Assert; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import reactor.core.publisher.Mono; + +/** + * Implementation of an HttpServlet based {@link McpStatelessServerTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@WebServlet(asyncSupported = true) +public class HttpServletStatelessServerTransport extends HttpServlet implements McpStatelessServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(HttpServletStatelessServerTransport.class); + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static final String ACCEPT = "Accept"; + + public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + + private final McpJsonMapper jsonMapper; + + private final String mcpEndpoint; + + private McpStatelessServerHandler mcpHandler; + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private HttpServletStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + + this.jsonMapper = jsonMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + } + + @Override + public void setMcpHandler(McpStatelessServerHandler mcpHandler) { + this.mcpHandler = mcpHandler; + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> this.isClosing = true); + } + + /** + * Handles GET requests - returns 405 METHOD NOT ALLOWED as stateless transport + * doesn't support GET requests. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); + } + + /** + * Handles POST requests for incoming JSON-RPC messages from clients. + * @param request The HTTP servlet request containing the JSON-RPC message + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + String accept = request.getHeader(ACCEPT); + if (accept == null || !(accept.contains(APPLICATION_JSON) && accept.contains(TEXT_EVENT_STREAM))) { + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("Both application/json and text/event-stream required in Accept header")); + return; + } + + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body.toString()); + + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + try { + McpSchema.JSONRPCResponse jsonrpcResponse = this.mcpHandler + .handleRequest(transportContext, jsonrpcRequest) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_OK); + + String jsonResponseText = jsonMapper.writeValueAsString(jsonrpcResponse); + PrintWriter writer = response.getWriter(); + writer.write(jsonResponseText); + writer.flush(); + } + catch (Exception e) { + logger.error("Failed to handle request: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Failed to handle request: " + e.getMessage())); + } + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + try { + this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + response.setStatus(HttpServletResponse.SC_ACCEPTED); + } + catch (Exception e) { + logger.error("Failed to handle notification: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Failed to handle notification: " + e.getMessage())); + } + } + else { + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("The server accepts either requests or notifications")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Unexpected error handling message: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Unexpected error: " + e.getMessage())); + } + } + + /** + * Sends an error response to the client. + * @param response The HTTP servlet response + * @param httpCode The HTTP status code + * @param mcpError The MCP error to send + * @throws IOException If an I/O error occurs + */ + private void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(httpCode); + String jsonError = jsonMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } + + /** + * Cleans up resources when the servlet is being destroyed. + *

    + * This method ensures a graceful shutdown before calling the parent's destroy method. + */ + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + /** + * Create a builder for the server. + * @return a fresh {@link Builder} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link HttpServletStatelessServerTransport}. + *

    + * This builder provides a fluent API for configuring and creating instances of + * HttpServletStatelessServerTransport with custom settings. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private Builder() { + // used by a static method + } + + /** + * Sets the JsonMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param jsonMapper The JsonMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if jsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link HttpServletStatelessServerTransport} with the + * configured settings. + * @return A new HttpServletStatelessServerTransport instance + * @throws IllegalStateException if required parameters are not set + */ + public HttpServletStatelessServerTransport build() { + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + return new HttpServletStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + mcpEndpoint, contextExtractor); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java new file mode 100644 index 000000000..34671c105 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -0,0 +1,851 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.json.TypeRef; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.util.KeepAliveScheduler; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Server-side implementation of the Model Context Protocol (MCP) streamable transport + * layer using HTTP with Server-Sent Events (SSE) through HttpServlet. This implementation + * provides a bridge between synchronous HttpServlet operations and reactive programming + * patterns to maintain compatibility with the reactive transport interface. + * + *

    + * This is the HttpServlet equivalent of + * {@link io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider} + * for the core MCP module, providing streamable HTTP transport functionality without + * Spring dependencies. + * + * @author Zachary German + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @see McpStreamableServerTransportProvider + * @see HttpServlet + */ +@WebServlet(asyncSupported = true) +public class HttpServletStreamableServerTransportProvider extends HttpServlet + implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(HttpServletStreamableServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Header name for the response media types accepted by the requester. + */ + private static final String ACCEPT = "Accept"; + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + + /** + * The endpoint URI where clients should send their JSON-RPC messages. Defaults to + * "/mcp". + */ + private final String mcpEndpoint; + + /** + * Flag indicating whether DELETE requests are disallowed on the endpoint. + */ + private final boolean disallowDelete; + + private final McpJsonMapper jsonMapper; + + private McpStreamableServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by mcp-session-id. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. + */ + private KeepAliveScheduler keepAliveScheduler; + + /** + * Constructs a new HttpServletStreamableServerTransportProvider instance. + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization of + * messages. + * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. + * @param disallowDelete Whether to disallow DELETE requests on the endpoint. + * @param contextExtractor The extractor for transport context from the request. + * @throws IllegalArgumentException if any parameter is null + */ + private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, + boolean disallowDelete, McpTransportContextExtractor contextExtractor, + Duration keepAliveInterval) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); + + this.jsonMapper = jsonMapper; + this.mcpEndpoint = mcpEndpoint; + this.disallowDelete = disallowDelete; + this.contextExtractor = contextExtractor; + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, + ProtocolVersions.MCP_2025_06_18); + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients through their SSE connections. + * If any errors occur during sending to a particular client, they are logged but + * don't prevent sending to other clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Object params) { + if (this.sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size()); + + return Mono.fromRunnable(() -> { + this.sessions.values().parallelStream().forEach(session -> { + try { + session.sendNotification(method, params).block(); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()); + } + }); + }); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when all cleanup operations are finished + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + this.isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()); + + this.sessions.values().parallelStream().forEach(session -> { + try { + session.closeGracefully().block(); + } + catch (Exception e) { + logger.error("Failed to close session {}: {}", session.getId(), e.getMessage()); + } + }); + + this.sessions.clear(); + logger.debug("Graceful shutdown completed"); + }).then().doOnSuccess(v -> { + sessions.clear(); + logger.debug("Graceful shutdown completed"); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + /** + * Handles GET requests to establish SSE connections and message replay. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (this.isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + List badRequestErrors = new ArrayList<>(); + + String accept = request.getHeader(ACCEPT); + if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) { + badRequestErrors.add("text/event-stream required in Accept header"); + } + + String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID); + + if (sessionId == null || sessionId.isBlank()) { + badRequestErrors.add("Session ID required in mcp-session-id header"); + } + + if (!badRequestErrors.isEmpty()) { + String combinedMessage = String.join("; ", badRequestErrors); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage)); + return; + } + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + logger.debug("Handling GET request for session: {}", sessionId); + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + try { + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport( + sessionId, asyncContext, response.getWriter()); + + // Check if this is a replay request + if (request.getHeader(HttpHeaders.LAST_EVENT_ID) != null) { + String lastId = request.getHeader(HttpHeaders.LAST_EVENT_ID); + + try { + session.replay(lastId) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .toIterable() + .forEach(message -> { + try { + sessionTransport.sendMessage(message) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to replay message: {}", e.getMessage()); + asyncContext.complete(); + } + }); + } + catch (Exception e) { + logger.error("Failed to replay messages: {}", e.getMessage()); + asyncContext.complete(); + } + } + else { + // Establish new listening stream + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + .listeningStream(sessionTransport); + + asyncContext.addListener(new jakarta.servlet.AsyncListener() { + @Override + public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection completed for session: {}", sessionId); + listeningStream.close(); + } + + @Override + public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection timed out for session: {}", sessionId); + listeningStream.close(); + } + + @Override + public void onError(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection error for session: {}", sessionId); + listeningStream.close(); + } + + @Override + public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException { + // No action needed + } + }); + } + } + catch (Exception e) { + logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + } + } + + /** + * Handles POST requests for incoming JSON-RPC messages from clients. + * @param request The HTTP servlet request containing the JSON-RPC message + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (this.isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + List badRequestErrors = new ArrayList<>(); + + String accept = request.getHeader(ACCEPT); + if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) { + badRequestErrors.add("text/event-stream required in Accept header"); + } + if (accept == null || !accept.contains(APPLICATION_JSON)) { + badRequestErrors.add("application/json required in Accept header"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body.toString()); + + // Handle initialization request + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + if (!badRequestErrors.isEmpty()) { + String combinedMessage = String.join("; ", badRequestErrors); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage)); + return; + } + + McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(), + new TypeRef() { + }); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); + this.sessions.put(init.session().getId(), init.session()); + + try { + McpSchema.InitializeResult initResult = init.initResult().block(); + + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setHeader(HttpHeaders.MCP_SESSION_ID, init.session().getId()); + response.setStatus(HttpServletResponse.SC_OK); + + String jsonResponse = jsonMapper.writeValueAsString(new McpSchema.JSONRPCResponse( + McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, null)); + + PrintWriter writer = response.getWriter(); + writer.write(jsonResponse); + writer.flush(); + return; + } + catch (Exception e) { + logger.error("Failed to initialize session: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Failed to initialize session: " + e.getMessage())); + return; + } + } + + String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID); + + if (sessionId == null || sessionId.isBlank()) { + badRequestErrors.add("Session ID required in mcp-session-id header"); + } + + if (!badRequestErrors.isEmpty()) { + String combinedMessage = String.join("; ", badRequestErrors); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage)); + return; + } + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + this.responseError(response, HttpServletResponse.SC_NOT_FOUND, + new McpError("Session not found: " + sessionId)); + return; + } + + if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { + session.accept(jsonrpcResponse) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + response.setStatus(HttpServletResponse.SC_ACCEPTED); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + session.accept(jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + response.setStatus(HttpServletResponse.SC_ACCEPTED); + } + else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + // For streaming responses, we need to return SSE + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport( + sessionId, asyncContext, response.getWriter()); + + try { + session.responseStream(jsonrpcRequest, sessionTransport) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to handle request stream: {}", e.getMessage()); + asyncContext.complete(); + } + } + else { + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("Invalid message format: " + e.getMessage())); + } + catch (Exception e) { + logger.error("Error handling message: {}", e.getMessage()); + try { + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Error processing message: " + e.getMessage())); + } + catch (IOException ex) { + logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message"); + } + } + } + + /** + * Handles DELETE requests for session deletion. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doDelete(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (this.isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + if (this.disallowDelete) { + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); + return; + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + if (request.getHeader(HttpHeaders.MCP_SESSION_ID) == null) { + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("Session ID required in mcp-session-id header")); + return; + } + + String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + try { + session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); + this.sessions.remove(sessionId); + response.setStatus(HttpServletResponse.SC_OK); + } + catch (Exception e) { + logger.error("Failed to delete session {}: {}", sessionId, e.getMessage()); + try { + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError(e.getMessage())); + } + catch (IOException ex) { + logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error deleting session"); + } + } + } + + public void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(httpCode); + String jsonError = jsonMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + /** + * Sends an SSE event to a client with a specific ID. + * @param writer The writer to send the event through + * @param eventType The type of event (message or endpoint) + * @param data The event data + * @param id The event ID + * @throws IOException If an error occurs while writing the event + */ + private void sendEvent(PrintWriter writer, String eventType, String data, String id) throws IOException { + if (id != null) { + writer.write("id: " + id + "\n"); + } + writer.write("event: " + eventType + "\n"); + writer.write("data: " + data + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + + /** + * Cleans up resources when the servlet is being destroyed. + *

    + * This method ensures a graceful shutdown by closing all client connections before + * calling the parent's destroy method. + */ + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + /** + * Implementation of McpStreamableServerTransport for HttpServlet SSE sessions. This + * class handles the transport-level communication for a specific client session. + * + *

    + * This class is thread-safe and uses a ReentrantLock to synchronize access to the + * underlying PrintWriter to prevent race conditions when multiple threads attempt to + * send messages concurrently. + */ + + private class HttpServletStreamableMcpSessionTransport implements McpStreamableServerTransport { + + private final String sessionId; + + private final AsyncContext asyncContext; + + private final PrintWriter writer; + + private volatile boolean closed = false; + + private final ReentrantLock lock = new ReentrantLock(); + + /** + * Creates a new session transport with the specified ID and SSE writer. + * @param sessionId The unique identifier for this session + * @param asyncContext The async context for the session + * @param writer The writer for sending server events to the client + */ + HttpServletStreamableMcpSessionTransport(String sessionId, AsyncContext asyncContext, PrintWriter writer) { + this.sessionId = sessionId; + this.asyncContext = asyncContext; + this.writer = writer; + logger.debug("Streamable session transport {} initialized with SSE writer", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return sendMessage(message, null); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection with a + * specific message ID. + * @param message The JSON-RPC message to send + * @param messageId The message ID for SSE event identification + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { + return Mono.fromRunnable(() -> { + if (this.closed) { + logger.debug("Attempted to send message to closed session: {}", this.sessionId); + return; + } + + lock.lock(); + try { + if (this.closed) { + logger.debug("Session {} was closed during message send attempt", this.sessionId); + return; + } + + String jsonText = jsonMapper.writeValueAsString(message); + HttpServletStreamableServerTransportProvider.this.sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText, + messageId != null ? messageId : this.sessionId); + logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); + HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); + this.asyncContext.complete(); + } + finally { + lock.unlock(); + } + }); + } + + /** + * Converts data from one type to another using the configured JsonMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + HttpServletStreamableMcpSessionTransport.this.close(); + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + lock.lock(); + try { + if (this.closed) { + logger.debug("Session transport {} already closed", this.sessionId); + return; + } + + this.closed = true; + + // HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); + this.asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + } + finally { + lock.unlock(); + } + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of + * {@link HttpServletStreamableServerTransportProvider}. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String mcpEndpoint = "/mcp"; + + private boolean disallowDelete = false; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private Duration keepAliveInterval; + + /** + * Sets the JsonMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param jsonMapper The JsonMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if JsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param mcpEndpoint The MCP endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if mcpEndpoint is null + */ + public Builder mcpEndpoint(String mcpEndpoint) { + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + this.mcpEndpoint = mcpEndpoint; + return this; + } + + /** + * Sets whether to disallow DELETE requests on the endpoint. + * @param disallowDelete true to disallow DELETE requests, false otherwise + * @return this builder instance + */ + public Builder disallowDelete(boolean disallowDelete) { + this.disallowDelete = disallowDelete; + return this; + } + + /** + * Sets the context extractor for extracting transport context from the request. + * @param contextExtractor The context extractor to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets the keep-alive interval for the transport. If set, a keep-alive scheduler + * will be activated to periodically ping active sessions. + * @param keepAliveInterval The interval for keep-alive pings. If null, no + * keep-alive will be scheduled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of {@link HttpServletStreamableServerTransportProvider} + * with the configured settings. + * @return A new HttpServletStreamableServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public HttpServletStreamableServerTransportProvider build() { + Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); + return new HttpServletStreamableServerTransportProvider( + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, disallowDelete, + contextExtractor, keepAliveInterval); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java new file mode 100644 index 000000000..68be62931 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -0,0 +1,307 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.json.McpJsonMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +/** + * Implementation of the MCP Stdio transport provider for servers that communicates using + * standard input/output streams. Messages are exchanged as newline-delimited JSON-RPC + * messages over stdin/stdout, with errors and debug information sent to stderr. + * + * @author Christian Tzolov + */ +public class StdioServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(StdioServerTransportProvider.class); + + private final McpJsonMapper jsonMapper; + + private final InputStream inputStream; + + private final OutputStream outputStream; + + private McpServerSession session; + + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + private final Sinks.One inboundReady = Sinks.one(); + + /** + * Creates a new StdioServerTransportProvider with the specified ObjectMapper and + * System streams. + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization + */ + public StdioServerTransportProvider(McpJsonMapper jsonMapper) { + this(jsonMapper, System.in, System.out); + } + + /** + * Creates a new StdioServerTransportProvider with the specified ObjectMapper and + * streams. + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization + * @param inputStream The input stream to read from + * @param outputStream The output stream to write to + */ + public StdioServerTransportProvider(McpJsonMapper jsonMapper, InputStream inputStream, OutputStream outputStream) { + Assert.notNull(jsonMapper, "The JsonMapper can not be null"); + Assert.notNull(inputStream, "The InputStream can not be null"); + Assert.notNull(outputStream, "The OutputStream can not be null"); + + this.jsonMapper = jsonMapper; + this.inputStream = inputStream; + this.outputStream = outputStream; + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + // Create a single session for the stdio connection + var transport = new StdioMcpSessionTransport(); + this.session = sessionFactory.create(transport); + transport.initProcessing(); + } + + @Override + public Mono notifyClients(String method, Object params) { + if (this.session == null) { + return Mono.error(new McpError("No session to close")); + } + return this.session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); + } + + @Override + public Mono closeGracefully() { + if (this.session == null) { + return Mono.empty(); + } + return this.session.closeGracefully(); + } + + /** + * Implementation of McpServerTransport for the stdio session. + */ + private class StdioMcpSessionTransport implements McpServerTransport { + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + private final AtomicBoolean isStarted = new AtomicBoolean(false); + + /** Scheduler for handling inbound messages */ + private Scheduler inboundScheduler; + + /** Scheduler for handling outbound messages */ + private Scheduler outboundScheduler; + + private final Sinks.One outboundReady = Sinks.one(); + + public StdioMcpSessionTransport() { + + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + + // Use bounded schedulers for better resource management + this.inboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "stdio-inbound"); + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "stdio-outbound"); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + + return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { + if (outboundSink.tryEmitNext(message).isSuccess()) { + return Mono.empty(); + } + else { + return Mono.error(new RuntimeException("Failed to enqueue message")); + } + })); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing.set(true); + logger.debug("Session transport closing gracefully"); + inboundSink.tryEmitComplete(); + }); + } + + @Override + public void close() { + isClosing.set(true); + logger.debug("Session transport closed"); + } + + private void initProcessing() { + handleIncomingMessages(); + startInboundProcessing(); + startOutboundProcessing(); + } + + private void handleIncomingMessages() { + this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { + // The outbound processing will dispose its scheduler upon completion + this.outboundSink.tryEmitComplete(); + this.inboundScheduler.dispose(); + }).subscribe(); + } + + /** + * Starts the inbound processing thread that reads JSON-RPC messages from stdin. + * Messages are deserialized and passed to the session for handling. + */ + private void startInboundProcessing() { + if (isStarted.compareAndSet(false, true)) { + this.inboundScheduler.schedule(() -> { + inboundReady.tryEmitValue(null); + BufferedReader reader = null; + try { + reader = new BufferedReader(new InputStreamReader(inputStream)); + while (!isClosing.get()) { + try { + String line = reader.readLine(); + if (line == null || isClosing.get()) { + break; + } + + logger.debug("Received JSON message: {}", line); + + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, + line); + if (!this.inboundSink.tryEmitNext(message).isSuccess()) { + // logIfNotClosing("Failed to enqueue message"); + break; + } + + } + catch (Exception e) { + logIfNotClosing("Error processing inbound message", e); + break; + } + } + catch (IOException e) { + logIfNotClosing("Error reading from stdin", e); + break; + } + } + } + catch (Exception e) { + logIfNotClosing("Error in inbound processing", e); + } + finally { + isClosing.set(true); + if (session != null) { + session.close(); + } + inboundSink.tryEmitComplete(); + } + }); + } + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to stdout. + * Messages are serialized to JSON and written with a newline delimiter. + */ + private void startOutboundProcessing() { + Function, Flux> outboundConsumer = messages -> messages // @formatter:off + .doOnSubscribe(subscription -> outboundReady.tryEmitValue(null)) + .publishOn(outboundScheduler) + .handle((message, sink) -> { + if (message != null && !isClosing.get()) { + try { + String jsonMessage = jsonMapper.writeValueAsString(message); + // Escape any embedded newlines in the JSON message as per spec + jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n"); + + synchronized (outputStream) { + outputStream.write(jsonMessage.getBytes(StandardCharsets.UTF_8)); + outputStream.write("\n".getBytes(StandardCharsets.UTF_8)); + outputStream.flush(); + } + sink.next(message); + } + catch (IOException e) { + if (!isClosing.get()) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } + else { + logger.debug("Stream closed during shutdown", e); + } + } + } + else if (isClosing.get()) { + sink.complete(); + } + }) + .doOnComplete(() -> { + isClosing.set(true); + outboundScheduler.dispose(); + }) + .doOnError(e -> { + if (!isClosing.get()) { + logger.error("Error in outbound processing", e); + isClosing.set(true); + outboundScheduler.dispose(); + } + }) + .map(msg -> (JSONRPCMessage) msg); + + outboundConsumer.apply(outboundSink.asFlux()).subscribe(); + } // @formatter:on + + private void logIfNotClosing(String message, Exception e) { + if (!isClosing.get()) { + logger.error(message, e); + } + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/ClosedMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ClosedMcpTransportSession.java new file mode 100644 index 000000000..b18364abb --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ClosedMcpTransportSession.java @@ -0,0 +1,58 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ +package io.modelcontextprotocol.spec; + +import java.util.Optional; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +/** + * Represents a closed MCP session, which may not be reused. All calls will throw a + * {@link McpTransportSessionClosedException}. + * + * @param the resource representing the connection that the transport + * manages. + * @author Daniel Garnier-Moiroux + */ +public class ClosedMcpTransportSession implements McpTransportSession { + + private final String sessionId; + + public ClosedMcpTransportSession(@Nullable String sessionId) { + this.sessionId = sessionId; + } + + @Override + public Optional sessionId() { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public boolean markInitialized(String sessionId) { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public void addConnection(CONNECTION connection) { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public void removeConnection(CONNECTION connection) { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public void close() { + + } + + @Override + public Publisher closeGracefully() { + return Mono.empty(); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java new file mode 100644 index 000000000..f497afd43 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; + +/** + * A default implementation of {@link McpStreamableServerSession.Factory}. + * + * @author Dariusz Jędrzejczyk + */ +public class DefaultMcpStreamableServerSessionFactory implements McpStreamableServerSession.Factory { + + Duration requestTimeout; + + McpStreamableServerSession.InitRequestHandler initRequestHandler; + + Map> requestHandlers; + + Map notificationHandlers; + + /** + * Constructs an instance + * @param requestTimeout timeout for requests + * @param initRequestHandler initialization request handler + * @param requestHandlers map of MCP request handlers keyed by method name + * @param notificationHandlers map of MCP notification handlers keyed by method name + */ + public DefaultMcpStreamableServerSessionFactory(Duration requestTimeout, + McpStreamableServerSession.InitRequestHandler initRequestHandler, + Map> requestHandlers, + Map notificationHandlers) { + this.requestTimeout = requestTimeout; + this.initRequestHandler = initRequestHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public McpStreamableServerSession.McpStreamableServerSessionInit startSession( + McpSchema.InitializeRequest initializeRequest) { + return new McpStreamableServerSession.McpStreamableServerSessionInit( + new McpStreamableServerSession(UUID.randomUUID().toString(), initializeRequest.capabilities(), + initializeRequest.clientInfo(), requestTimeout, requestHandlers, notificationHandlers), + this.initRequestHandler.handle(initializeRequest)); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java new file mode 100644 index 000000000..fdb7bfd89 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java @@ -0,0 +1,84 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.publisher.Mono; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +/** + * Default implementation of {@link McpTransportSession} which manages the open + * connections using tye {@link Disposable} type and allows to perform clean up using the + * {@link Disposable#dispose()} method. + * + * @author Dariusz Jędrzejczyk + */ +public class DefaultMcpTransportSession implements McpTransportSession { + + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpTransportSession.class); + + private final Disposable.Composite openConnections = Disposables.composite(); + + private final AtomicBoolean initialized = new AtomicBoolean(false); + + private final AtomicReference sessionId = new AtomicReference<>(); + + private final Function> onClose; + + public DefaultMcpTransportSession(Function> onClose) { + this.onClose = onClose; + } + + @Override + public Optional sessionId() { + return Optional.ofNullable(this.sessionId.get()); + } + + @Override + public boolean markInitialized(String sessionId) { + boolean flipped = this.initialized.compareAndSet(false, true); + if (flipped) { + this.sessionId.set(sessionId); + logger.debug("Established session with id {}", sessionId); + } + else { + if (sessionId != null && !sessionId.equals(this.sessionId.get())) { + logger.warn("Different session id provided in response. Expecting {} but server returned {}", + this.sessionId.get(), sessionId); + } + } + return flipped; + } + + @Override + public void addConnection(Disposable connection) { + this.openConnections.add(connection); + } + + @Override + public void removeConnection(Disposable connection) { + this.openConnections.remove(connection); + } + + @Override + public void close() { + this.closeGracefully().subscribe(); + } + + @Override + public Mono closeGracefully() { + return Mono.from(this.onClose.apply(this.sessionId.get())) + .then(Mono.fromRunnable(this.openConnections::dispose)); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java new file mode 100644 index 000000000..8d63fb50d --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java @@ -0,0 +1,83 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +/** + * An implementation of {@link McpTransportStream} using Project Reactor types. + * + * @param the resource serving the stream + * @author Dariusz Jędrzejczyk + */ +public class DefaultMcpTransportStream implements McpTransportStream { + + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpTransportStream.class); + + private static final AtomicLong counter = new AtomicLong(); + + private final AtomicReference lastId = new AtomicReference<>(); + + // Used only for internal accounting + private final long streamId; + + private final boolean resumable; + + private final Function, Publisher> reconnect; + + /** + * Constructs a new instance representing a particular stream that can resume using + * the provided reconnect mechanism. + * @param resumable whether the stream is resumable and should try to reconnect + * @param reconnect the mechanism to use in case an error is observed on the current + * event stream to asynchronously kick off a resumed stream consumption, potentially + * using the stored {@link #lastId()}. + */ + public DefaultMcpTransportStream(boolean resumable, + Function, Publisher> reconnect) { + this.reconnect = reconnect; + this.streamId = counter.getAndIncrement(); + this.resumable = resumable; + } + + @Override + public Optional lastId() { + return Optional.ofNullable(this.lastId.get()); + } + + @Override + public long streamId() { + return this.streamId; + } + + @Override + public Publisher consumeSseStream( + Publisher, Iterable>> eventStream) { + + // @formatter:off + return Flux.deferContextual(ctx -> Flux.from(eventStream) + .doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(id -> { + String previousId = this.lastId.getAndSet(id); + logger.debug("Updating last id {} -> {} for stream {}", previousId, id, this.streamId); + })) + .doOnError(e -> { + if (resumable && !(e instanceof McpTransportSessionNotFoundException)) { + Mono.from(reconnect.apply(this)).contextWrite(ctx).subscribe(); + } + }) + .flatMapIterable(Tuple2::getT2)); // @formatter:on + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java new file mode 100644 index 000000000..6afc2c119 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java @@ -0,0 +1,56 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +/** + * Names of HTTP headers in use by MCP HTTP transports. + * + * @author Dariusz Jędrzejczyk + */ +public interface HttpHeaders { + + /** + * Identifies individual MCP sessions. + */ + String MCP_SESSION_ID = "Mcp-Session-Id"; + + /** + * Identifies events within an SSE Stream. + */ + String LAST_EVENT_ID = "Last-Event-ID"; + + /** + * Identifies the MCP protocol version. + */ + String PROTOCOL_VERSION = "MCP-Protocol-Version"; + + /** + * The HTTP Content-Length header. + * @see RFC9110 + */ + String CONTENT_LENGTH = "Content-Length"; + + /** + * The HTTP Content-Type header. + * @see RFC9110 + */ + String CONTENT_TYPE = "Content-Type"; + + /** + * The HTTP Accept header. + * @see RFC9110 + */ + String ACCEPT = "Accept"; + + /** + * The HTTP Cache-Control header. + * @see RFC9111 + */ + String CACHE_CONTROL = "Cache-Control"; + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java new file mode 100644 index 000000000..4a42c9ff3 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.util.Map; + +/** + * Interface for validating structured content against a JSON schema. This interface + * defines a method to validate structured content based on the provided output schema. + * + * @author Christian Tzolov + */ +public interface JsonSchemaValidator { + + /** + * Represents the result of a validation operation. + * + * @param valid Indicates whether the validation was successful. + * @param errorMessage An error message if the validation failed, otherwise null. + * @param jsonStructuredOutput The text structured content in JSON format if the + * validation was successful, otherwise null. + */ + public record ValidationResponse(boolean valid, String errorMessage, String jsonStructuredOutput) { + + public static ValidationResponse asValid(String jsonStructuredOutput) { + return new ValidationResponse(true, null, jsonStructuredOutput); + } + + public static ValidationResponse asInvalid(String message) { + return new ValidationResponse(false, message, null); + } + } + + /** + * Validates the structured content against the provided JSON schema. + * @param schema The JSON schema to validate against. + * @param structuredContent The structured content to validate. + * @return A ValidationResponse indicating whether the validation was successful or + * not. + */ + ValidationResponse validate(Map schema, Object structuredContent); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java similarity index 60% rename from mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index e2d354f4a..0ba7ab3b8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -4,20 +4,21 @@ package io.modelcontextprotocol.spec; -import java.time.Duration; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; - -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; +import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.Disposable; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; + /** * Default implementation of the MCP (Model Context Protocol) session that manages * bidirectional JSON-RPC communication between clients and servers. This implementation @@ -34,17 +35,17 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Yanming Zhou */ -public class DefaultMcpSession implements McpSession { +public class McpClientSession implements McpSession { - /** Logger for this class */ - private static final Logger logger = LoggerFactory.getLogger(DefaultMcpSession.class); + private static final Logger logger = LoggerFactory.getLogger(McpClientSession.class); /** Duration to wait for request responses before timing out */ private final Duration requestTimeout; /** Transport layer implementation for message exchange */ - private final McpTransport transport; + private final McpClientTransport transport; /** Map of pending responses keyed by request ID */ private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); @@ -61,8 +62,6 @@ public class DefaultMcpSession implements McpSession { /** Atomic counter for generating unique request IDs */ private final AtomicLong requestCounter = new AtomicLong(0); - private final Disposable connection; - /** * Functional interface for handling incoming JSON-RPC requests. Implementations * should process the request parameters and return a response. @@ -98,16 +97,34 @@ public interface NotificationHandler { } /** - * Creates a new DefaultMcpSession with the specified configuration and handlers. + * Creates a new McpClientSession with the specified configuration and handlers. * @param requestTimeout Duration to wait for responses * @param transport Transport implementation for message exchange * @param requestHandlers Map of method names to request handlers * @param notificationHandlers Map of method names to notification handlers + * @deprecated Use + * {@link #McpClientSession(Duration, McpClientTransport, Map, Map, Function)} */ - public DefaultMcpSession(Duration requestTimeout, McpTransport transport, + @Deprecated + public McpClientSession(Duration requestTimeout, McpClientTransport transport, Map> requestHandlers, Map notificationHandlers) { + this(requestTimeout, transport, requestHandlers, notificationHandlers, Function.identity()); + } + + /** + * Creates a new McpClientSession with the specified configuration and handlers. + * @param requestTimeout Duration to wait for responses + * @param transport Transport implementation for message exchange + * @param requestHandlers Map of method names to request handlers + * @param notificationHandlers Map of method names to notification handlers + * @param connectHook Hook that allows transforming the connection Publisher prior to + * subscribing + */ + public McpClientSession(Duration requestTimeout, McpClientTransport transport, + Map> requestHandlers, Map notificationHandlers, + Function, ? extends Publisher> connectHook) { - Assert.notNull(requestTimeout, "The requstTimeout can not be null"); + Assert.notNull(requestTimeout, "The requestTimeout can not be null"); Assert.notNull(transport, "The transport can not be null"); Assert.notNull(requestHandlers, "The requestHandlers can not be null"); Assert.notNull(notificationHandlers, "The notificationHandlers can not be null"); @@ -117,38 +134,63 @@ public DefaultMcpSession(Duration requestTimeout, McpTransport transport, this.requestHandlers.putAll(requestHandlers); this.notificationHandlers.putAll(notificationHandlers); - // TODO: consider mono.transformDeferredContextual where the Context contains - // the - // Observation associated with the individual message - it can be used to - // create child Observation and emit it together with the message to the - // consumer - this.connection = this.transport.connect(mono -> mono.doOnNext(message -> { - if (message instanceof McpSchema.JSONRPCResponse response) { - logger.debug("Received Response: {}", response); + this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe(); + } + + private void dismissPendingResponses() { + this.pendingResponses.forEach((id, sink) -> { + logger.warn("Abruptly terminating exchange for request {}", id); + sink.error(new RuntimeException("MCP session with server terminated")); + }); + this.pendingResponses.clear(); + } + + private void handle(McpSchema.JSONRPCMessage message) { + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received response: {}", response); + if (response.id() != null) { var sink = pendingResponses.remove(response.id()); if (sink == null) { - logger.warn("Unexpected response for unkown id {}", response.id()); + logger.warn("Unexpected response for unknown id {}", response.id()); } else { sink.success(response); } } - else if (message instanceof McpSchema.JSONRPCRequest request) { - logger.debug("Received request: {}", request); - handleIncomingRequest(request).subscribe(response -> transport.sendMessage(response).subscribe(), - error -> { - var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError( - McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); - transport.sendMessage(errorResponse).subscribe(); - }); - } - else if (message instanceof McpSchema.JSONRPCNotification notification) { - logger.debug("Received notification: {}", notification); - handleIncomingNotification(notification).subscribe(null, - error -> logger.error("Error handling notification: {}", error.getMessage())); + else { + logger.error("Discarded MCP request response without session id. " + + "This is an indication of a bug in the request sender code that can lead to memory " + + "leaks as pending requests will never be completed."); } - })).subscribe(); + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + handleIncomingRequest(request).onErrorResume(error -> { + + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + // TODO: add error message through the data field + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); + + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + jsonRpcError); + return Mono.just(errorResponse); + }).flatMap(this.transport::sendMessage).onErrorComplete(t -> { + logger.warn("Issue sending response to the client, ", t); + return true; + }).subscribe(); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + logger.debug("Received notification: {}", notification); + handleIncomingNotification(notification).onErrorComplete(t -> { + logger.error("Error handling notification: {}", t.getMessage()); + return true; + }).subscribe(); + } + else { + logger.warn("Received unknown message type: {}", message); + } } /** @@ -167,18 +209,14 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR } return handler.handle(request.params()) - .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)))); // TODO: add error message - // through the data field + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)); }); } record MethodNotFoundError(String method, String message, Object data) { } - public static MethodNotFoundError getMethodNotFoundError(String method) { + private MethodNotFoundError getMethodNotFoundError(String method) { switch (method) { case McpSchema.METHOD_ROOTS_LIST: return new MethodNotFoundError(method, "Roots not supported", @@ -197,7 +235,7 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti return Mono.defer(() -> { var handler = notificationHandlers.get(notification.method()); if (handler == null) { - logger.error("No handler registered for notification method: {}", notification.method()); + logger.warn("No handler registered for notification method: {}", notification); return Mono.empty(); } return handler.handle(notification.params()); @@ -222,30 +260,30 @@ private String generateRequestId() { * @return A Mono containing the response */ @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { String requestId = this.generateRequestId(); - return Mono.create(sink -> { - this.pendingResponses.put(requestId, sink); + return Mono.deferContextual(ctx -> Mono.create(pendingResponseSink -> { + logger.debug("Sending message for method {}", method); + this.pendingResponses.put(requestId, pendingResponseSink); McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, requestId, requestParams); - this.transport.sendMessage(jsonrpcRequest) - // TODO: It's most efficient to create a dedicated Subscriber here - .subscribe(v -> { - }, error -> { - this.pendingResponses.remove(requestId); - sink.error(error); - }); - }).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { + this.transport.sendMessage(jsonrpcRequest).contextWrite(ctx).subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + pendingResponseSink.error(error); + }); + })).timeout(this.requestTimeout).handle((jsonRpcResponse, deliveredResponseSink) -> { if (jsonRpcResponse.error() != null) { - sink.error(new McpError(jsonRpcResponse.error())); + logger.error("Error handling request: {}", jsonRpcResponse.error()); + deliveredResponseSink.error(new McpError(jsonRpcResponse.error())); } else { if (typeRef.getType().equals(Void.class)) { - sink.complete(); + deliveredResponseSink.complete(); } else { - sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + deliveredResponseSink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); } } }); @@ -258,7 +296,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc * @return A Mono that completes when the notification is sent */ @Override - public Mono sendNotification(String method, Map params) { + public Mono sendNotification(String method, Object params) { McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); return this.transport.sendMessage(jsonrpcNotification); @@ -270,8 +308,7 @@ public Mono sendNotification(String method, Map params) { */ @Override public Mono closeGracefully() { - this.connection.dispose(); - return transport.closeGracefully(); + return Mono.fromRunnable(this::dismissPendingResponses); } /** @@ -279,8 +316,7 @@ public Mono closeGracefully() { */ @Override public void close() { - this.connection.dispose(); - transport.close(); + dismissPendingResponses(); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java new file mode 100644 index 000000000..22aec831b --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -0,0 +1,41 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ + +package io.modelcontextprotocol.spec; + +import java.util.function.Consumer; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +/** + * Interface for the client side of the {@link McpTransport}. It allows setting handlers + * for messages that are incoming from the MCP server and hooking in to exceptions raised + * on the transport layer. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public interface McpClientTransport extends McpTransport { + + /** + * Used to register the incoming messages' handler and potentially (eagerly) connect + * to the server. + * @param handler a transformer for incoming messages + * @return a {@link Mono} that terminates upon successful client setup. It can mean + * establishing a connection (which can be later disposed) but it doesn't have to, + * depending on the transport type. The successful termination of the returned + * {@link Mono} simply means the client can now be used. An error can be retried + * according to the application requirements. + */ + Mono connect(Function, Mono> handler); + + /** + * Sets the exception handler for exceptions raised on the transport layer. + * @param handler Allows reacting to transport level exceptions by the higher layers + */ + default void setExceptionHandler(Consumer handler) { + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpError.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpError.java new file mode 100644 index 000000000..d6e549fdc --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpError.java @@ -0,0 +1,116 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ + +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse.JSONRPCError; +import io.modelcontextprotocol.util.Assert; + +import java.util.Map; +import java.util.function.Function; + +public class McpError extends RuntimeException { + + /** + * Resource + * Error Handling + */ + public static final Function RESOURCE_NOT_FOUND = resourceUri -> new McpError(new JSONRPCError( + McpSchema.ErrorCodes.RESOURCE_NOT_FOUND, "Resource not found", Map.of("uri", resourceUri))); + + private JSONRPCError jsonRpcError; + + public McpError(JSONRPCError jsonRpcError) { + super(jsonRpcError.message()); + this.jsonRpcError = jsonRpcError; + } + + @Deprecated + public McpError(Object error) { + super(error.toString()); + } + + public JSONRPCError getJsonRpcError() { + return jsonRpcError; + } + + @Override + public String toString() { + var builder = new StringBuilder(super.toString()); + if (jsonRpcError != null) { + builder.append("\n"); + builder.append(jsonRpcError.toString()); + } + return builder.toString(); + } + + public static Builder builder(int errorCode) { + return new Builder(errorCode); + } + + public static class Builder { + + private final int code; + + private String message; + + private Object data; + + private Builder(int code) { + this.code = code; + } + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder data(Object data) { + this.data = data; + return this; + } + + public McpError build() { + Assert.hasText(message, "message must not be empty"); + return new McpError(new JSONRPCError(code, message, data)); + } + + } + + public static Throwable findRootCause(Throwable throwable) { + Assert.notNull(throwable, "throwable must not be null"); + Throwable rootCause = throwable; + while (rootCause.getCause() != null && rootCause.getCause() != rootCause) { + rootCause = rootCause.getCause(); + } + return rootCause; + } + + public static String aggregateExceptionMessages(Throwable throwable) { + Assert.notNull(throwable, "throwable must not be null"); + + StringBuilder messages = new StringBuilder(); + Throwable current = throwable; + + while (current != null) { + if (messages.length() > 0) { + messages.append("\n Caused by: "); + } + + messages.append(current.getClass().getSimpleName()); + if (current.getMessage() != null) { + messages.append(": ").append(current.getMessage()); + } + + if (current.getCause() == current) { + break; + } + current = current.getCause(); + } + + return messages.toString(); + } + +} \ No newline at end of file diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java new file mode 100644 index 000000000..f43a2c1d9 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java @@ -0,0 +1,29 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +/** + * An {@link McpSession} which is capable of processing logging notifications and keeping + * track of a min logging level. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpLoggableSession extends McpSession { + + /** + * Set the minimum logging level for the client. Messages below this level will be + * filtered out. + * @param minLoggingLevel The minimum logging level + */ + void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel); + + /** + * Allows checking whether a particular logging level is allowed. + * @param loggingLevel the level to check + * @return whether the logging at the specified level is permitted. + */ + boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java new file mode 100644 index 000000000..b58f1c552 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -0,0 +1,2931 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Based on the JSON-RPC 2.0 + * specification and the Model + * Context Protocol Schema. + * + * @author Christian Tzolov + * @author Luca Chang + * @author Surbhi Bansal + * @author Anurag Pant + */ +public final class McpSchema { + + private static final Logger logger = LoggerFactory.getLogger(McpSchema.class); + + private McpSchema() { + } + + @Deprecated + public static final String LATEST_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_06_18; + + public static final String JSONRPC_VERSION = "2.0"; + + public static final String FIRST_PAGE = null; + + // --------------------------- + // Method Names + // --------------------------- + + // Lifecycle Methods + public static final String METHOD_INITIALIZE = "initialize"; + + public static final String METHOD_NOTIFICATION_INITIALIZED = "notifications/initialized"; + + public static final String METHOD_PING = "ping"; + + public static final String METHOD_NOTIFICATION_PROGRESS = "notifications/progress"; + + // Tool Methods + public static final String METHOD_TOOLS_LIST = "tools/list"; + + public static final String METHOD_TOOLS_CALL = "tools/call"; + + public static final String METHOD_NOTIFICATION_TOOLS_LIST_CHANGED = "notifications/tools/list_changed"; + + // Resources Methods + public static final String METHOD_RESOURCES_LIST = "resources/list"; + + public static final String METHOD_RESOURCES_READ = "resources/read"; + + public static final String METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED = "notifications/resources/list_changed"; + + public static final String METHOD_NOTIFICATION_RESOURCES_UPDATED = "notifications/resources/updated"; + + public static final String METHOD_RESOURCES_TEMPLATES_LIST = "resources/templates/list"; + + public static final String METHOD_RESOURCES_SUBSCRIBE = "resources/subscribe"; + + public static final String METHOD_RESOURCES_UNSUBSCRIBE = "resources/unsubscribe"; + + // Prompt Methods + public static final String METHOD_PROMPT_LIST = "prompts/list"; + + public static final String METHOD_PROMPT_GET = "prompts/get"; + + public static final String METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED = "notifications/prompts/list_changed"; + + public static final String METHOD_COMPLETION_COMPLETE = "completion/complete"; + + // Logging Methods + public static final String METHOD_LOGGING_SET_LEVEL = "logging/setLevel"; + + public static final String METHOD_NOTIFICATION_MESSAGE = "notifications/message"; + + // Roots Methods + public static final String METHOD_ROOTS_LIST = "roots/list"; + + public static final String METHOD_NOTIFICATION_ROOTS_LIST_CHANGED = "notifications/roots/list_changed"; + + // Sampling Methods + public static final String METHOD_SAMPLING_CREATE_MESSAGE = "sampling/createMessage"; + + // Elicitation Methods + public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; + + // --------------------------- + // JSON-RPC Error Codes + // --------------------------- + /** + * Standard error codes used in MCP JSON-RPC responses. + */ + public static final class ErrorCodes { + + /** + * Invalid JSON was received by the server. + */ + public static final int PARSE_ERROR = -32700; + + /** + * The JSON sent is not a valid Request object. + */ + public static final int INVALID_REQUEST = -32600; + + /** + * The method does not exist / is not available. + */ + public static final int METHOD_NOT_FOUND = -32601; + + /** + * Invalid method parameter(s). + */ + public static final int INVALID_PARAMS = -32602; + + /** + * Internal JSON-RPC error. + */ + public static final int INTERNAL_ERROR = -32603; + + /** + * Resource not found. + */ + public static final int RESOURCE_NOT_FOUND = -32002; + + } + + /** + * Base interface for MCP objects that include optional metadata in the `_meta` field. + */ + public interface Meta { + + /** + * @see Specification + * for notes on _meta usage + * @return additional metadata related to this resource. + */ + Map meta(); + + } + + public sealed interface Request extends Meta + permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, CompleteRequest, + GetPromptRequest, ReadResourceRequest, SubscribeRequest, UnsubscribeRequest, PaginatedRequest { + + default Object progressToken() { + if (meta() != null && meta().containsKey("progressToken")) { + return meta().get("progressToken"); + } + return null; + } + + } + + public sealed interface Result extends Meta permits InitializeResult, ListResourcesResult, + ListResourceTemplatesResult, ReadResourceResult, ListPromptsResult, GetPromptResult, ListToolsResult, + CallToolResult, CreateMessageResult, ElicitResult, CompleteResult, ListRootsResult { + + } + + public sealed interface Notification extends Meta + permits ProgressNotification, LoggingMessageNotification, ResourcesUpdatedNotification { + + } + + private static final TypeRef> MAP_TYPE_REF = new TypeRef<>() { + }; + + /** + * Deserializes a JSON string into a JSONRPCMessage object. + * @param jsonMapper The JsonMapper instance to use for deserialization + * @param jsonText The JSON string to deserialize + * @return A JSONRPCMessage instance using either the {@link JSONRPCRequest}, + * {@link JSONRPCNotification}, or {@link JSONRPCResponse} classes. + * @throws IOException If there's an error during deserialization + * @throws IllegalArgumentException If the JSON structure doesn't match any known + * message type + */ + public static JSONRPCMessage deserializeJsonRpcMessage(McpJsonMapper jsonMapper, String jsonText) + throws IOException { + + logger.debug("Received JSON message: {}", jsonText); + + var map = jsonMapper.readValue(jsonText, MAP_TYPE_REF); + + // Determine message type based on specific JSON structure + if (map.containsKey("method") && map.containsKey("id")) { + return jsonMapper.convertValue(map, JSONRPCRequest.class); + } + else if (map.containsKey("method") && !map.containsKey("id")) { + return jsonMapper.convertValue(map, JSONRPCNotification.class); + } + else if (map.containsKey("result") || map.containsKey("error")) { + return jsonMapper.convertValue(map, JSONRPCResponse.class); + } + + throw new IllegalArgumentException("Cannot deserialize JSONRPCMessage: " + jsonText); + } + + // --------------------------- + // JSON-RPC Message Types + // --------------------------- + public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotification, JSONRPCResponse { + + String jsonrpc(); + + } + + /** + * A request that expects a response. + * + * @param jsonrpc The JSON-RPC version (must be "2.0") + * @param method The name of the method to be invoked + * @param id A unique identifier for the request + * @param params Parameters for the method call + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) + public record JSONRPCRequest( // @formatter:off + @JsonProperty("jsonrpc") String jsonrpc, + @JsonProperty("method") String method, + @JsonProperty("id") Object id, + @JsonProperty("params") Object params) implements JSONRPCMessage { // @formatter:on + + /** + * Constructor that validates MCP-specific ID requirements. Unlike base JSON-RPC, + * MCP requires that: (1) Requests MUST include a string or integer ID; (2) The ID + * MUST NOT be null + */ + public JSONRPCRequest { + Assert.notNull(id, "MCP requests MUST include an ID - null IDs are not allowed"); + Assert.isTrue(id instanceof String || id instanceof Integer || id instanceof Long, + "MCP requests MUST have an ID that is either a string or integer"); + } + } + + /** + * A notification which does not expect a response. + * + * @param jsonrpc The JSON-RPC version (must be "2.0") + * @param method The name of the method being notified + * @param params Parameters for the notification + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + // TODO: batching support + // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) + public record JSONRPCNotification( // @formatter:off + @JsonProperty("jsonrpc") String jsonrpc, + @JsonProperty("method") String method, + @JsonProperty("params") Object params) implements JSONRPCMessage { // @formatter:on + } + + /** + * A response to a request (successful, or error). + * + * @param jsonrpc The JSON-RPC version (must be "2.0") + * @param id The request identifier that this response corresponds to + * @param result The result of the successful request; null if error + * @param error Error information if the request failed; null if has result + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + // TODO: batching support + // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) + public record JSONRPCResponse( // @formatter:off + @JsonProperty("jsonrpc") String jsonrpc, + @JsonProperty("id") Object id, + @JsonProperty("result") Object result, + @JsonProperty("error") JSONRPCError error) implements JSONRPCMessage { // @formatter:on + + /** + * A response to a request that indicates an error occurred. + * + * @param code The error type that occurred + * @param message A short description of the error. The message SHOULD be limited + * to a concise single sentence + * @param data Additional information about the error. The value of this member is + * defined by the sender (e.g. detailed error information, nested errors etc.) + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record JSONRPCError( // @formatter:off + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { // @formatter:on + } + } + + // --------------------------- + // Initialization + // --------------------------- + /** + * This request is sent from the client to the server when it first connects, asking + * it to begin initialization. + * + * @param protocolVersion The latest version of the Model Context Protocol that the + * client supports. The client MAY decide to support older versions as well + * @param capabilities The capabilities that the client supports + * @param clientInfo Information about the client implementation + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record InitializeRequest( // @formatter:off + @JsonProperty("protocolVersion") String protocolVersion, + @JsonProperty("capabilities") ClientCapabilities capabilities, + @JsonProperty("clientInfo") Implementation clientInfo, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public InitializeRequest(String protocolVersion, ClientCapabilities capabilities, Implementation clientInfo) { + this(protocolVersion, capabilities, clientInfo, null); + } + } + + /** + * After receiving an initialize request from the client, the server sends this + * response. + * + * @param protocolVersion The version of the Model Context Protocol that the server + * wants to use. This may not match the version that the client requested. If the + * client cannot support this version, it MUST disconnect + * @param capabilities The capabilities that the server supports + * @param serverInfo Information about the server implementation + * @param instructions Instructions describing how to use the server and its features. + * This can be used by clients to improve the LLM's understanding of available tools, + * resources, etc. It can be thought of like a "hint" to the model. For example, this + * information MAY be added to the system prompt + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record InitializeResult( // @formatter:off + @JsonProperty("protocolVersion") String protocolVersion, + @JsonProperty("capabilities") ServerCapabilities capabilities, + @JsonProperty("serverInfo") Implementation serverInfo, + @JsonProperty("instructions") String instructions, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public InitializeResult(String protocolVersion, ServerCapabilities capabilities, Implementation serverInfo, + String instructions) { + this(protocolVersion, capabilities, serverInfo, instructions, null); + } + } + + /** + * Capabilities a client may support. Known capabilities are defined here, in this + * schema, but this is not a closed set: any client can define its own, additional + * capabilities. + * + * @param experimental Experimental, non-standard capabilities that the client + * supports + * @param roots Present if the client supports listing roots + * @param sampling Present if the client supports sampling from an LLM + * @param elicitation Present if the client supports elicitation from the server + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ClientCapabilities( // @formatter:off + @JsonProperty("experimental") Map experimental, + @JsonProperty("roots") RootCapabilities roots, + @JsonProperty("sampling") Sampling sampling, + @JsonProperty("elicitation") Elicitation elicitation) { // @formatter:on + + /** + * Present if the client supports listing roots. + * + * @param listChanged Whether the client supports notifications for changes to the + * roots list + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record RootCapabilities(@JsonProperty("listChanged") Boolean listChanged) { + } + + /** + * Provides a standardized way for servers to request LLM sampling ("completions" + * or "generations") from language models via clients. This flow allows clients to + * maintain control over model access, selection, and permissions while enabling + * servers to leverage AI capabilities—with no server API keys necessary. Servers + * can request text or image-based interactions and optionally include context + * from MCP servers in their prompts. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Sampling() { + } + + /** + * Provides a standardized way for servers to request additional information from + * users through the client during interactions. This flow allows clients to + * maintain control over user interactions and data sharing while enabling servers + * to gather necessary information dynamically. Servers can request structured + * data from users with optional JSON schemas to validate responses. + * + *

    + * Per the 2025-11-25 spec, clients can declare support for specific elicitation + * modes: + *

      + *
    • {@code form} - In-band structured data collection with optional schema + * validation
    • + *
    • {@code url} - Out-of-band interaction via URL navigation
    • + *
    + * + *

    + * For backward compatibility, an empty elicitation object {@code {}} is + * equivalent to declaring support for form mode only. + * + * @param form support for in-band form-based elicitation + * @param url support for out-of-band URL-based elicitation + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Elicitation(@JsonProperty("form") Form form, @JsonProperty("url") Url url) { + + /** + * Marker record indicating support for form-based elicitation mode. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Form() { + } + + /** + * Marker record indicating support for URL-based elicitation mode. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Url() { + } + + /** + * Creates an Elicitation with default settings (backward compatible, produces + * empty JSON object). + */ + public Elicitation() { + this(null, null); + } + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private Map experimental; + + private RootCapabilities roots; + + private Sampling sampling; + + private Elicitation elicitation; + + public Builder experimental(Map experimental) { + this.experimental = experimental; + return this; + } + + public Builder roots(Boolean listChanged) { + this.roots = new RootCapabilities(listChanged); + return this; + } + + public Builder sampling() { + this.sampling = new Sampling(); + return this; + } + + /** + * Enables elicitation capability with default settings (backward compatible, + * produces empty JSON object). + * @return this builder + */ + public Builder elicitation() { + this.elicitation = new Elicitation(); + return this; + } + + /** + * Enables elicitation capability with explicit form and/or url mode support. + * @param form whether to support form-based elicitation + * @param url whether to support URL-based elicitation + * @return this builder + */ + public Builder elicitation(boolean form, boolean url) { + this.elicitation = new Elicitation(form ? new Elicitation.Form() : null, + url ? new Elicitation.Url() : null); + return this; + } + + public ClientCapabilities build() { + return new ClientCapabilities(experimental, roots, sampling, elicitation); + } + + } + } + + /** + * Capabilities that a server may support. Known capabilities are defined here, in + * this schema, but this is not a closed set: any server can define its own, + * additional capabilities. + * + * @param completions Present if the server supports argument autocompletion + * suggestions + * @param experimental Experimental, non-standard capabilities that the server + * supports + * @param logging Present if the server supports sending log messages to the client + * @param prompts Present if the server offers any prompt templates + * @param resources Present if the server offers any resources to read + * @param tools Present if the server offers any tools to call + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ServerCapabilities( // @formatter:off + @JsonProperty("completions") CompletionCapabilities completions, + @JsonProperty("experimental") Map experimental, + @JsonProperty("logging") LoggingCapabilities logging, + @JsonProperty("prompts") PromptCapabilities prompts, + @JsonProperty("resources") ResourceCapabilities resources, + @JsonProperty("tools") ToolCapabilities tools) { // @formatter:on + + /** + * Present if the server supports argument autocompletion suggestions. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record CompletionCapabilities() { + } + + /** + * Present if the server supports sending log messages to the client. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record LoggingCapabilities() { + } + + /** + * Present if the server offers any prompt templates. + * + * @param listChanged Whether this server supports notifications for changes to + * the prompt list + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record PromptCapabilities(@JsonProperty("listChanged") Boolean listChanged) { + } + + /** + * Present if the server offers any resources to read. + * + * @param subscribe Whether this server supports subscribing to resource updates + * @param listChanged Whether this server supports notifications for changes to + * the resource list + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record ResourceCapabilities(@JsonProperty("subscribe") Boolean subscribe, + @JsonProperty("listChanged") Boolean listChanged) { + } + + /** + * Present if the server offers any tools to call. + * + * @param listChanged Whether this server supports notifications for changes to + * the tool list + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record ToolCapabilities(@JsonProperty("listChanged") Boolean listChanged) { + } + + /** + * Create a mutated copy of this object with the specified changes. + * @return A new Builder instance with the same values as this object. + */ + public Builder mutate() { + var builder = new Builder(); + builder.completions = this.completions; + builder.experimental = this.experimental; + builder.logging = this.logging; + builder.prompts = this.prompts; + builder.resources = this.resources; + builder.tools = this.tools; + return builder; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private CompletionCapabilities completions; + + private Map experimental; + + private LoggingCapabilities logging; + + private PromptCapabilities prompts; + + private ResourceCapabilities resources; + + private ToolCapabilities tools; + + public Builder completions() { + this.completions = new CompletionCapabilities(); + return this; + } + + public Builder experimental(Map experimental) { + this.experimental = experimental; + return this; + } + + public Builder logging() { + this.logging = new LoggingCapabilities(); + return this; + } + + public Builder prompts(Boolean listChanged) { + this.prompts = new PromptCapabilities(listChanged); + return this; + } + + public Builder resources(Boolean subscribe, Boolean listChanged) { + this.resources = new ResourceCapabilities(subscribe, listChanged); + return this; + } + + public Builder tools(Boolean listChanged) { + this.tools = new ToolCapabilities(listChanged); + return this; + } + + public ServerCapabilities build() { + return new ServerCapabilities(completions, experimental, logging, prompts, resources, tools); + } + + } + } + + /** + * Describes the name and version of an MCP implementation, with an optional title for + * UI representation. + * + * @param name Intended for programmatic or logical use, but used as a display name in + * past specs or fallback (if title isn't present). + * @param title Intended for UI and end-user contexts + * @param version The version of the implementation. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Implementation( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("version") String version) implements Identifier { // @formatter:on + + public Implementation(String name, String version) { + this(name, null, version); + } + } + + // Existing Enums and Base Types (from previous implementation) + public enum Role { + + // @formatter:off + @JsonProperty("user") USER, + @JsonProperty("assistant") ASSISTANT + } // @formatter:on + + // --------------------------- + // Resource Interfaces + // --------------------------- + /** + * Base for objects that include optional annotations for the client. The client can + * use annotations to inform how objects are used or displayed + */ + public interface Annotated { + + Annotations annotations(); + + } + + /** + * Optional annotations for the client. The client can use annotations to inform how + * objects are used or displayed. + * + * @param audience Describes who the intended customer of this object or data is. It + * can include multiple entries to indicate content useful for multiple audiences + * (e.g., `["user", "assistant"]`). + * @param priority Describes how important this data is for operating the server. A + * value of 1 means "most important," and indicates that the data is effectively + * required, while 0 means "least important," and indicates that the data is entirely + * optional. It is a number between 0 and 1. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Annotations( // @formatter:off + @JsonProperty("audience") List audience, + @JsonProperty("priority") Double priority, + @JsonProperty("lastModified") String lastModified + ) { // @formatter:on + + public Annotations(List audience, Double priority) { + this(audience, priority, null); + } + } + + /** + * A common interface for resource content, which includes metadata about the resource + * such as its URI, name, description, MIME type, size, and annotations. This + * interface is implemented by both {@link Resource} and {@link ResourceLink} to + * provide a consistent way to access resource metadata. + */ + public interface ResourceContent extends Identifier, Annotated, Meta { + + // name & title from Identifier + + String uri(); + + String description(); + + String mimeType(); + + Long size(); + + // annotations from Annotated + // meta from Meta + + } + + /** + * Base interface with name (identifier) and title (display name) properties. + */ + public interface Identifier { + + /** + * Intended for programmatic or logical use, but used as a display name in past + * specs or fallback (if title isn't present). + */ + String name(); + + /** + * Intended for UI and end-user contexts — optimized to be human-readable and + * easily understood, even by those unfamiliar with domain-specific terminology. + * + * If not provided, the name should be used for display. + */ + String title(); + + } + + /** + * A known resource that the server is capable of reading. + * + * @param uri the URI of the resource. + * @param name A human-readable name for this resource. This can be used by clients to + * populate UI elements. + * @param title An optional title for this resource. + * @param description A description of what this resource represents. This can be used + * by clients to improve the LLM's understanding of available resources. It can be + * thought of like a "hint" to the model. + * @param mimeType The MIME type of this resource, if known. + * @param size The size of the raw resource content, in bytes (i.e., before base64 + * encoding or any tokenization), if known. This can be used by Hosts to display file + * sizes and estimate context window usage. + * @param annotations Optional annotations for the client. The client can use + * annotations to inform how objects are used or displayed. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Resource( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("description") String description, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("size") Long size, + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("_meta") Map meta) implements ResourceContent { // @formatter:on + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link Resource#builder()} instead. + */ + @Deprecated + public Resource(String uri, String name, String title, String description, String mimeType, Long size, + Annotations annotations) { + this(uri, name, title, description, mimeType, size, annotations, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link Resource#builder()} instead. + */ + @Deprecated + public Resource(String uri, String name, String description, String mimeType, Long size, + Annotations annotations) { + this(uri, name, null, description, mimeType, size, annotations, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link Resource#builder()} instead. + */ + @Deprecated + public Resource(String uri, String name, String description, String mimeType, Annotations annotations) { + this(uri, name, null, description, mimeType, null, annotations, null); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String uri; + + private String name; + + private String title; + + private String description; + + private String mimeType; + + private Long size; + + private Annotations annotations; + + private Map meta; + + public Builder uri(String uri) { + this.uri = uri; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder mimeType(String mimeType) { + this.mimeType = mimeType; + return this; + } + + public Builder size(Long size) { + this.size = size; + return this; + } + + public Builder annotations(Annotations annotations) { + this.annotations = annotations; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public Resource build() { + Assert.hasText(uri, "uri must not be empty"); + Assert.hasText(name, "name must not be empty"); + + return new Resource(uri, name, title, description, mimeType, size, annotations, meta); + } + + } + } + + /** + * Resource templates allow servers to expose parameterized resources using URI + * + * @param uriTemplate A URI template that can be used to generate URIs for this + * resource. + * @param name A human-readable name for this resource. This can be used by clients to + * populate UI elements. + * @param title An optional title for this resource. + * @param description A description of what this resource represents. This can be used + * by clients to improve the LLM's understanding of available resources. It can be + * thought of like a "hint" to the model. + * @param mimeType The MIME type of this resource, if known. + * @param annotations Optional annotations for the client. The client can use + * annotations to inform how objects are used or displayed. + * @see RFC 6570 + * @param meta See specification for notes on _meta usage + * + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ResourceTemplate( // @formatter:off + @JsonProperty("uriTemplate") String uriTemplate, + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("description") String description, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("_meta") Map meta) implements Annotated, Identifier, Meta { // @formatter:on + + public ResourceTemplate(String uriTemplate, String name, String title, String description, String mimeType, + Annotations annotations) { + this(uriTemplate, name, title, description, mimeType, annotations, null); + } + + public ResourceTemplate(String uriTemplate, String name, String description, String mimeType, + Annotations annotations) { + this(uriTemplate, name, null, description, mimeType, annotations); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String uriTemplate; + + private String name; + + private String title; + + private String description; + + private String mimeType; + + private Annotations annotations; + + private Map meta; + + public Builder uriTemplate(String uri) { + this.uriTemplate = uri; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder mimeType(String mimeType) { + this.mimeType = mimeType; + return this; + } + + public Builder annotations(Annotations annotations) { + this.annotations = annotations; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public ResourceTemplate build() { + Assert.hasText(uriTemplate, "uri must not be empty"); + Assert.hasText(name, "name must not be empty"); + + return new ResourceTemplate(uriTemplate, name, title, description, mimeType, annotations, meta); + } + + } + } + + /** + * The server's response to a resources/list request from the client. + * + * @param resources A list of resources that the server provides + * @param nextCursor An opaque token representing the pagination position after the + * last returned result. If present, there may be more results available + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ListResourcesResult( // @formatter:off + @JsonProperty("resources") List resources, + @JsonProperty("nextCursor") String nextCursor, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public ListResourcesResult(List resources, String nextCursor) { + this(resources, nextCursor, null); + } + } + + /** + * The server's response to a resources/templates/list request from the client. + * + * @param resourceTemplates A list of resource templates that the server provides + * @param nextCursor An opaque token representing the pagination position after the + * last returned result. If present, there may be more results available + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ListResourceTemplatesResult( // @formatter:off + @JsonProperty("resourceTemplates") List resourceTemplates, + @JsonProperty("nextCursor") String nextCursor, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public ListResourceTemplatesResult(List resourceTemplates, String nextCursor) { + this(resourceTemplates, nextCursor, null); + } + } + + /** + * Sent from the client to the server, to read a specific resource URI. + * + * @param uri The URI of the resource to read. The URI can use any protocol; it is up + * to the server how to interpret it + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ReadResourceRequest( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public ReadResourceRequest(String uri) { + this(uri, null); + } + } + + /** + * The server's response to a resources/read request from the client. + * + * @param contents The contents of the resource + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ReadResourceResult( // @formatter:off + @JsonProperty("contents") List contents, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public ReadResourceResult(List contents) { + this(contents, null); + } + } + + /** + * Sent from the client to request resources/updated notifications from the server + * whenever a particular resource changes. + * + * @param uri the URI of the resource to subscribe to. The URI can use any protocol; + * it is up to the server how to interpret it. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record SubscribeRequest( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public SubscribeRequest(String uri) { + this(uri, null); + } + } + + /** + * Sent from the client to request cancellation of resources/updated notifications + * from the server. This should follow a previous resources/subscribe request. + * + * @param uri The URI of the resource to unsubscribe from + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record UnsubscribeRequest( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public UnsubscribeRequest(String uri) { + this(uri, null); + } + } + + /** + * The contents of a specific resource or sub-resource. + */ + @JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION) + @JsonSubTypes({ @JsonSubTypes.Type(value = TextResourceContents.class), + @JsonSubTypes.Type(value = BlobResourceContents.class) }) + public sealed interface ResourceContents extends Meta permits TextResourceContents, BlobResourceContents { + + /** + * The URI of this resource. + * @return the URI of this resource. + */ + String uri(); + + /** + * The MIME type of this resource. + * @return the MIME type of this resource. + */ + String mimeType(); + + } + + /** + * Text contents of a resource. + * + * @param uri the URI of this resource. + * @param mimeType the MIME type of this resource. + * @param text the text of the resource. This must only be set if the resource can + * actually be represented as text (not binary data). + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record TextResourceContents( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("text") String text, + @JsonProperty("_meta") Map meta) implements ResourceContents { // @formatter:on + + public TextResourceContents(String uri, String mimeType, String text) { + this(uri, mimeType, text, null); + } + } + + /** + * Binary contents of a resource. + * + * @param uri the URI of this resource. + * @param mimeType the MIME type of this resource. + * @param blob a base64-encoded string representing the binary data of the resource. + * This must only be set if the resource can actually be represented as binary data + * (not text). + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record BlobResourceContents( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("blob") String blob, + @JsonProperty("_meta") Map meta) implements ResourceContents { // @formatter:on + + public BlobResourceContents(String uri, String mimeType, String blob) { + this(uri, mimeType, blob, null); + } + } + + // --------------------------- + // Prompt Interfaces + // --------------------------- + /** + * A prompt or prompt template that the server offers. + * + * @param name The name of the prompt or prompt template. + * @param title An optional title for the prompt. + * @param description An optional description of what this prompt provides. + * @param arguments A list of arguments to use for templating the prompt. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Prompt( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("description") String description, + @JsonProperty("arguments") List arguments, + @JsonProperty("_meta") Map meta) implements Identifier { // @formatter:on + + public Prompt(String name, String description, List arguments) { + this(name, null, description, arguments != null ? arguments : new ArrayList<>()); + } + + public Prompt(String name, String title, String description, List arguments) { + this(name, title, description, arguments != null ? arguments : new ArrayList<>(), null); + } + } + + /** + * Describes an argument that a prompt can accept. + * + * @param name The name of the argument. + * @param title An optional title for the argument, which can be used in UI + * @param description A human-readable description of the argument. + * @param required Whether this argument must be provided. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PromptArgument( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("description") String description, + @JsonProperty("required") Boolean required) implements Identifier { // @formatter:on + + public PromptArgument(String name, String description, Boolean required) { + this(name, null, description, required); + } + } + + /** + * Describes a message returned as part of a prompt. + * + * This is similar to `SamplingMessage`, but also supports the embedding of resources + * from the MCP server. + * + * @param role The sender or recipient of messages and data in a conversation. + * @param content The content of the message of type {@link Content}. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PromptMessage( // @formatter:off + @JsonProperty("role") Role role, + @JsonProperty("content") Content content) { // @formatter:on + } + + /** + * The server's response to a prompts/list request from the client. + * + * @param prompts A list of prompts that the server provides. + * @param nextCursor An optional cursor for pagination. If present, indicates there + * are more prompts available. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ListPromptsResult( // @formatter:off + @JsonProperty("prompts") List prompts, + @JsonProperty("nextCursor") String nextCursor, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public ListPromptsResult(List prompts, String nextCursor) { + this(prompts, nextCursor, null); + } + } + + /** + * Used by the client to get a prompt provided by the server. + * + * @param name The name of the prompt or prompt template. + * @param arguments Arguments to use for templating the prompt. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record GetPromptRequest( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("arguments") Map arguments, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public GetPromptRequest(String name, Map arguments) { + this(name, arguments, null); + } + } + + /** + * The server's response to a prompts/get request from the client. + * + * @param description An optional description for the prompt. + * @param messages A list of messages to display as part of the prompt. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record GetPromptResult( // @formatter:off + @JsonProperty("description") String description, + @JsonProperty("messages") List messages, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public GetPromptResult(String description, List messages) { + this(description, messages, null); + } + } + + // --------------------------- + // Tool Interfaces + // --------------------------- + /** + * The server's response to a tools/list request from the client. + * + * @param tools A list of tools that the server provides. + * @param nextCursor An optional cursor for pagination. If present, indicates there + * are more tools available. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ListToolsResult( // @formatter:off + @JsonProperty("tools") List tools, + @JsonProperty("nextCursor") String nextCursor, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public ListToolsResult(List tools, String nextCursor) { + this(tools, nextCursor, null); + } + } + + /** + * A JSON Schema object that describes the expected structure of arguments or output. + * + * @param type The type of the schema (e.g., "object") + * @param properties The properties of the schema object + * @param required List of required property names + * @param additionalProperties Whether additional properties are allowed + * @param defs Schema definitions using the newer $defs keyword + * @param definitions Schema definitions using the legacy definitions keyword + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record JsonSchema( // @formatter:off + @JsonProperty("type") String type, + @JsonProperty("properties") Map properties, + @JsonProperty("required") List required, + @JsonProperty("additionalProperties") Boolean additionalProperties, + @JsonProperty("$defs") Map defs, + @JsonProperty("definitions") Map definitions) { // @formatter:on + } + + /** + * Additional properties describing a Tool to clients. + * + * NOTE: all properties in ToolAnnotations are **hints**. They are not guaranteed to + * provide a faithful description of tool behavior (including descriptive properties + * like `title`). + * + * Clients should never make tool use decisions based on ToolAnnotations received from + * untrusted servers. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ToolAnnotations( // @formatter:off + @JsonProperty("title") String title, + @JsonProperty("readOnlyHint") Boolean readOnlyHint, + @JsonProperty("destructiveHint") Boolean destructiveHint, + @JsonProperty("idempotentHint") Boolean idempotentHint, + @JsonProperty("openWorldHint") Boolean openWorldHint, + @JsonProperty("returnDirect") Boolean returnDirect) { // @formatter:on + } + + /** + * Represents a tool that the server provides. Tools enable servers to expose + * executable functionality to the system. Through these tools, you can interact with + * external systems, perform computations, and take actions in the real world. + * + * @param name A unique identifier for the tool. This name is used when calling the + * tool. + * @param title A human-readable title for the tool. + * @param description A human-readable description of what the tool does. This can be + * used by clients to improve the LLM's understanding of available tools. + * @param inputSchema A JSON Schema object that describes the expected structure of + * the arguments when calling this tool. This allows clients to validate tool + * @param outputSchema An optional JSON Schema object defining the structure of the + * tool's output returned in the structuredContent field of a CallToolResult. + * @param annotations Optional additional tool information. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Tool( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("description") String description, + @JsonProperty("inputSchema") JsonSchema inputSchema, + @JsonProperty("outputSchema") Map outputSchema, + @JsonProperty("annotations") ToolAnnotations annotations, + @JsonProperty("_meta") Map meta) { // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String name; + + private String title; + + private String description; + + private JsonSchema inputSchema; + + private Map outputSchema; + + private ToolAnnotations annotations; + + private Map meta; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder inputSchema(JsonSchema inputSchema) { + this.inputSchema = inputSchema; + return this; + } + + public Builder inputSchema(McpJsonMapper jsonMapper, String inputSchema) { + this.inputSchema = parseSchema(jsonMapper, inputSchema); + return this; + } + + public Builder outputSchema(Map outputSchema) { + this.outputSchema = outputSchema; + return this; + } + + public Builder outputSchema(McpJsonMapper jsonMapper, String outputSchema) { + this.outputSchema = schemaToMap(jsonMapper, outputSchema); + return this; + } + + public Builder annotations(ToolAnnotations annotations) { + this.annotations = annotations; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public Tool build() { + Assert.hasText(name, "name must not be empty"); + return new Tool(name, title, description, inputSchema, outputSchema, annotations, meta); + } + + } + } + + private static Map schemaToMap(McpJsonMapper jsonMapper, String schema) { + try { + return jsonMapper.readValue(schema, MAP_TYPE_REF); + } + catch (IOException e) { + throw new IllegalArgumentException("Invalid schema: " + schema, e); + } + } + + private static JsonSchema parseSchema(McpJsonMapper jsonMapper, String schema) { + try { + return jsonMapper.readValue(schema, JsonSchema.class); + } + catch (IOException e) { + throw new IllegalArgumentException("Invalid schema: " + schema, e); + } + } + + /** + * Used by the client to call a tool provided by the server. + * + * @param name The name of the tool to call. This must match a tool name from + * tools/list. + * @param arguments Arguments to pass to the tool. These must conform to the tool's + * input schema. + * @param meta Optional metadata about the request. This can include additional + * information like `progressToken` + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CallToolRequest( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("arguments") Map arguments, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public CallToolRequest(McpJsonMapper jsonMapper, String name, String jsonArguments) { + this(name, parseJsonArguments(jsonMapper, jsonArguments), null); + } + + public CallToolRequest(String name, Map arguments) { + this(name, arguments, null); + } + + private static Map parseJsonArguments(McpJsonMapper jsonMapper, String jsonArguments) { + try { + return jsonMapper.readValue(jsonArguments, MAP_TYPE_REF); + } + catch (IOException e) { + throw new IllegalArgumentException("Invalid arguments: " + jsonArguments, e); + } + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String name; + + private Map arguments; + + private Map meta; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder arguments(Map arguments) { + this.arguments = arguments; + return this; + } + + public Builder arguments(McpJsonMapper jsonMapper, String jsonArguments) { + this.arguments = parseJsonArguments(jsonMapper, jsonArguments); + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public Builder progressToken(Object progressToken) { + if (this.meta == null) { + this.meta = new HashMap<>(); + } + this.meta.put("progressToken", progressToken); + return this; + } + + public CallToolRequest build() { + Assert.hasText(name, "name must not be empty"); + return new CallToolRequest(name, arguments, meta); + } + + } + } + + /** + * The server's response to a tools/call request from the client. + * + * @param content A list of content items representing the tool's output. Each item + * can be text, an image, or an embedded resource. + * @param isError If true, indicates that the tool execution failed and the content + * contains error information. If false or absent, indicates successful execution. + * @param structuredContent An optional JSON object that represents the structured + * result of the tool call. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CallToolResult( // @formatter:off + @JsonProperty("content") List content, + @JsonProperty("isError") Boolean isError, + @JsonProperty("structuredContent") Object structuredContent, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + /** + * @deprecated use the builder instead. + */ + @Deprecated + public CallToolResult(List content, Boolean isError) { + this(content, isError, (Object) null, null); + } + + /** + * @deprecated use the builder instead. + */ + @Deprecated + public CallToolResult(List content, Boolean isError, Map structuredContent) { + this(content, isError, structuredContent, null); + } + + /** + * Creates a new instance of {@link CallToolResult} with a string containing the + * tool result. + * @param content The content of the tool result. This will be mapped to a + * one-sized list with a {@link TextContent} element. + * @param isError If true, indicates that the tool execution failed and the + * content contains error information. If false or absent, indicates successful + * execution. + */ + @Deprecated + public CallToolResult(String content, Boolean isError) { + this(List.of(new TextContent(content)), isError, null); + } + + /** + * Creates a builder for {@link CallToolResult}. + * @return a new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link CallToolResult}. + */ + public static class Builder { + + private List content = new ArrayList<>(); + + private Boolean isError = false; + + private Object structuredContent; + + private Map meta; + + /** + * Sets the content list for the tool result. + * @param content the content list + * @return this builder + */ + public Builder content(List content) { + Assert.notNull(content, "content must not be null"); + this.content = content; + return this; + } + + public Builder structuredContent(Object structuredContent) { + Assert.notNull(structuredContent, "structuredContent must not be null"); + this.structuredContent = structuredContent; + return this; + } + + public Builder structuredContent(McpJsonMapper jsonMapper, String structuredContent) { + Assert.hasText(structuredContent, "structuredContent must not be empty"); + try { + this.structuredContent = jsonMapper.readValue(structuredContent, MAP_TYPE_REF); + } + catch (IOException e) { + throw new IllegalArgumentException("Invalid structured content: " + structuredContent, e); + } + return this; + } + + /** + * Sets the text content for the tool result. + * @param textContent the text content + * @return this builder + */ + public Builder textContent(List textContent) { + Assert.notNull(textContent, "textContent must not be null"); + textContent.stream().map(TextContent::new).forEach(this.content::add); + return this; + } + + /** + * Adds a content item to the tool result. + * @param contentItem the content item to add + * @return this builder + */ + public Builder addContent(Content contentItem) { + Assert.notNull(contentItem, "contentItem must not be null"); + if (this.content == null) { + this.content = new ArrayList<>(); + } + this.content.add(contentItem); + return this; + } + + /** + * Adds a text content item to the tool result. + * @param text the text content + * @return this builder + */ + public Builder addTextContent(String text) { + Assert.notNull(text, "text must not be null"); + return addContent(new TextContent(text)); + } + + /** + * Sets whether the tool execution resulted in an error. + * @param isError true if the tool execution failed, false otherwise + * @return this builder + */ + public Builder isError(Boolean isError) { + Assert.notNull(isError, "isError must not be null"); + this.isError = isError; + return this; + } + + /** + * Sets the metadata for the tool result. + * @param meta metadata + * @return this builder + */ + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + /** + * Builds a new {@link CallToolResult} instance. + * @return a new CallToolResult instance + */ + public CallToolResult build() { + return new CallToolResult(content, isError, structuredContent, meta); + } + + } + + } + + // --------------------------- + // Sampling Interfaces + // --------------------------- + /** + * The server's preferences for model selection, requested of the client during + * sampling. + * + * @param hints Optional hints to use for model selection. If multiple hints are + * specified, the client MUST evaluate them in order (such that the first match is + * taken). The client SHOULD prioritize these hints over the numeric priorities, but + * MAY still use the priorities to select from ambiguous matches + * @param costPriority How much to prioritize cost when selecting a model. A value of + * 0 means cost is not important, while a value of 1 means cost is the most important + * factor + * @param speedPriority How much to prioritize sampling speed (latency) when selecting + * a model. A value of 0 means speed is not important, while a value of 1 means speed + * is the most important factor + * @param intelligencePriority How much to prioritize intelligence and capabilities + * when selecting a model. A value of 0 means intelligence is not important, while a + * value of 1 means intelligence is the most important factor + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ModelPreferences( // @formatter:off + @JsonProperty("hints") List hints, + @JsonProperty("costPriority") Double costPriority, + @JsonProperty("speedPriority") Double speedPriority, + @JsonProperty("intelligencePriority") Double intelligencePriority) { // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private List hints; + + private Double costPriority; + + private Double speedPriority; + + private Double intelligencePriority; + + public Builder hints(List hints) { + this.hints = hints; + return this; + } + + public Builder addHint(String name) { + if (this.hints == null) { + this.hints = new ArrayList<>(); + } + this.hints.add(new ModelHint(name)); + return this; + } + + public Builder costPriority(Double costPriority) { + this.costPriority = costPriority; + return this; + } + + public Builder speedPriority(Double speedPriority) { + this.speedPriority = speedPriority; + return this; + } + + public Builder intelligencePriority(Double intelligencePriority) { + this.intelligencePriority = intelligencePriority; + return this; + } + + public ModelPreferences build() { + return new ModelPreferences(hints, costPriority, speedPriority, intelligencePriority); + } + + } + } + + /** + * Hints to use for model selection. + * + * @param name A hint for a model name. The client SHOULD treat this as a substring of + * a model name; for example: `claude-3-5-sonnet` should match + * `claude-3-5-sonnet-20241022`, `sonnet` should match `claude-3-5-sonnet-20241022`, + * `claude-3-sonnet-20240229`, etc., `claude` should match any Claude model. The + * client MAY also map the string to a different provider's model name or a different + * model family, as long as it fills a similar niche + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ModelHint(@JsonProperty("name") String name) { + public static ModelHint of(String name) { + return new ModelHint(name); + } + } + + /** + * Describes a message issued to or received from an LLM API. + * + * @param role The sender or recipient of messages and data in a conversation + * @param content The content of the message + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record SamplingMessage( // @formatter:off + @JsonProperty("role") Role role, + @JsonProperty("content") Content content) { // @formatter:on + } + + /** + * A request from the server to sample an LLM via the client. The client has full + * discretion over which model to select. The client should also inform the user + * before beginning sampling, to allow them to inspect the request (human in the loop) + * and decide whether to approve it. + * + * @param messages The conversation messages to send to the LLM + * @param modelPreferences The server's preferences for which model to select. The + * client MAY ignore these preferences + * @param systemPrompt An optional system prompt the server wants to use for sampling. + * The client MAY modify or omit this prompt + * @param includeContext A request to include context from one or more MCP servers + * (including the caller), to be attached to the prompt. The client MAY ignore this + * request + * @param temperature Optional temperature parameter for sampling + * @param maxTokens The maximum number of tokens to sample, as requested by the + * server. The client MAY choose to sample fewer tokens than requested + * @param stopSequences Optional stop sequences for sampling + * @param metadata Optional metadata to pass through to the LLM provider. The format + * of this metadata is provider-specific + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CreateMessageRequest( // @formatter:off + @JsonProperty("messages") List messages, + @JsonProperty("modelPreferences") ModelPreferences modelPreferences, + @JsonProperty("systemPrompt") String systemPrompt, + @JsonProperty("includeContext") ContextInclusionStrategy includeContext, + @JsonProperty("temperature") Double temperature, + @JsonProperty("maxTokens") Integer maxTokens, + @JsonProperty("stopSequences") List stopSequences, + @JsonProperty("metadata") Map metadata, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + // backwards compatibility constructor + public CreateMessageRequest(List messages, ModelPreferences modelPreferences, + String systemPrompt, ContextInclusionStrategy includeContext, Double temperature, Integer maxTokens, + List stopSequences, Map metadata) { + this(messages, modelPreferences, systemPrompt, includeContext, temperature, maxTokens, stopSequences, + metadata, null); + } + + public enum ContextInclusionStrategy { + + // @formatter:off + @JsonProperty("none") NONE, + @JsonProperty("thisServer") THIS_SERVER, + @JsonProperty("allServers")ALL_SERVERS + } // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private List messages; + + private ModelPreferences modelPreferences; + + private String systemPrompt; + + private ContextInclusionStrategy includeContext; + + private Double temperature; + + private Integer maxTokens; + + private List stopSequences; + + private Map metadata; + + private Map meta; + + public Builder messages(List messages) { + this.messages = messages; + return this; + } + + public Builder modelPreferences(ModelPreferences modelPreferences) { + this.modelPreferences = modelPreferences; + return this; + } + + public Builder systemPrompt(String systemPrompt) { + this.systemPrompt = systemPrompt; + return this; + } + + public Builder includeContext(ContextInclusionStrategy includeContext) { + this.includeContext = includeContext; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder stopSequences(List stopSequences) { + this.stopSequences = stopSequences; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public Builder progressToken(Object progressToken) { + if (this.meta == null) { + this.meta = new HashMap<>(); + } + this.meta.put("progressToken", progressToken); + return this; + } + + public CreateMessageRequest build() { + return new CreateMessageRequest(messages, modelPreferences, systemPrompt, includeContext, temperature, + maxTokens, stopSequences, metadata, meta); + } + + } + } + + /** + * The client's response to a sampling/create_message request from the server. The + * client should inform the user before returning the sampled message, to allow them + * to inspect the response (human in the loop) and decide whether to allow the server + * to see it. + * + * @param role The role of the message sender (typically assistant) + * @param content The content of the sampled message + * @param model The name of the model that generated the message + * @param stopReason The reason why sampling stopped, if known + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CreateMessageResult( // @formatter:off + @JsonProperty("role") Role role, + @JsonProperty("content") Content content, + @JsonProperty("model") String model, + @JsonProperty("stopReason") StopReason stopReason, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public enum StopReason { + + // @formatter:off + @JsonProperty("endTurn") END_TURN("endTurn"), + @JsonProperty("stopSequence") STOP_SEQUENCE("stopSequence"), + @JsonProperty("maxTokens") MAX_TOKENS("maxTokens"), + @JsonProperty("unknown") UNKNOWN("unknown"); + // @formatter:on + + private final String value; + + StopReason(String value) { + this.value = value; + } + + @JsonCreator + private static StopReason of(String value) { + return Arrays.stream(StopReason.values()) + .filter(stopReason -> stopReason.value.equals(value)) + .findFirst() + .orElse(StopReason.UNKNOWN); + } + + } + + public CreateMessageResult(Role role, Content content, String model, StopReason stopReason) { + this(role, content, model, stopReason, null); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private Role role = Role.ASSISTANT; + + private Content content; + + private String model; + + private StopReason stopReason = StopReason.END_TURN; + + private Map meta; + + public Builder role(Role role) { + this.role = role; + return this; + } + + public Builder content(Content content) { + this.content = content; + return this; + } + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder stopReason(StopReason stopReason) { + this.stopReason = stopReason; + return this; + } + + public Builder message(String message) { + this.content = new TextContent(message); + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public CreateMessageResult build() { + return new CreateMessageResult(role, content, model, stopReason, meta); + } + + } + } + + // Elicitation + /** + * A request from the server to elicit additional information from the user via the + * client. + * + * @param message The message to present to the user + * @param requestedSchema A restricted subset of JSON Schema. Only top-level + * properties are allowed, without nesting + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitRequest( // @formatter:off + @JsonProperty("message") String message, + @JsonProperty("requestedSchema") Map requestedSchema, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + // backwards compatibility constructor + public ElicitRequest(String message, Map requestedSchema) { + this(message, requestedSchema, null); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String message; + + private Map requestedSchema; + + private Map meta; + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder requestedSchema(Map requestedSchema) { + this.requestedSchema = requestedSchema; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public Builder progressToken(Object progressToken) { + if (this.meta == null) { + this.meta = new HashMap<>(); + } + this.meta.put("progressToken", progressToken); + return this; + } + + public ElicitRequest build() { + return new ElicitRequest(message, requestedSchema, meta); + } + + } + } + + /** + * The client's response to an elicitation request. + * + * @param action The user action in response to the elicitation. "accept": User + * submitted the form/confirmed the action, "decline": User explicitly declined the + * action, "cancel": User dismissed without making an explicit choice + * @param content The submitted form data, only present when action is "accept". + * Contains values matching the requested schema + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitResult( // @formatter:off + @JsonProperty("action") Action action, + @JsonProperty("content") Map content, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public enum Action { + + // @formatter:off + @JsonProperty("accept") ACCEPT, + @JsonProperty("decline") DECLINE, + @JsonProperty("cancel") CANCEL + } // @formatter:on + + // backwards compatibility constructor + public ElicitResult(Action action, Map content) { + this(action, content, null); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private Action action; + + private Map content; + + private Map meta; + + public Builder message(Action action) { + this.action = action; + return this; + } + + public Builder content(Map content) { + this.content = content; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public ElicitResult build() { + return new ElicitResult(action, content, meta); + } + + } + } + + // --------------------------- + // Pagination Interfaces + // --------------------------- + /** + * A request that supports pagination using cursors. + * + * @param cursor An opaque token representing the current pagination position. If + * provided, the server should return results starting after this cursor + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PaginatedRequest( // @formatter:off + @JsonProperty("cursor") String cursor, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public PaginatedRequest(String cursor) { + this(cursor, null); + } + + /** + * Creates a new paginated request with an empty cursor. + */ + public PaginatedRequest() { + this(null); + } + } + + /** + * An opaque token representing the pagination position after the last returned + * result. If present, there may be more results available. + * + * @param nextCursor An opaque token representing the pagination position after the + * last returned result. If present, there may be more results available + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PaginatedResult(@JsonProperty("nextCursor") String nextCursor) { + } + + // --------------------------- + // Progress and Logging + // --------------------------- + /** + * The Model Context Protocol (MCP) supports optional progress tracking for + * long-running operations through notification messages. Either side can send + * progress notifications to provide updates about operation status. + * + * @param progressToken A unique token to identify the progress notification. MUST be + * unique across all active requests. + * @param progress A value indicating the current progress. + * @param total An optional total amount of work to be done, if known. + * @param message An optional message providing additional context about the progress. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ProgressNotification( // @formatter:off + @JsonProperty("progressToken") Object progressToken, + @JsonProperty("progress") Double progress, + @JsonProperty("total") Double total, + @JsonProperty("message") String message, + @JsonProperty("_meta") Map meta) implements Notification { // @formatter:on + + public ProgressNotification(Object progressToken, double progress, Double total, String message) { + this(progressToken, progress, total, message, null); + } + } + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to send + * resources update message to clients. + * + * @param uri The updated resource uri. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ResourcesUpdatedNotification(// @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("_meta") Map meta) implements Notification { // @formatter:on + + public ResourcesUpdatedNotification(String uri) { + this(uri, null); + } + } + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to send + * structured log messages to clients. Clients can control logging verbosity by + * setting minimum log levels, with servers sending notifications containing severity + * levels, optional logger names, and arbitrary JSON-serializable data. + * + * @param level The severity levels. The minimum log level is set by the client. + * @param logger The logger that generated the message. + * @param data JSON-serializable logging data. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record LoggingMessageNotification( // @formatter:off + @JsonProperty("level") LoggingLevel level, + @JsonProperty("logger") String logger, + @JsonProperty("data") String data, + @JsonProperty("_meta") Map meta) implements Notification { // @formatter:on + + // backwards compatibility constructor + public LoggingMessageNotification(LoggingLevel level, String logger, String data) { + this(level, logger, data, null); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private LoggingLevel level = LoggingLevel.INFO; + + private String logger = "server"; + + private String data; + + private Map meta; + + public Builder level(LoggingLevel level) { + this.level = level; + return this; + } + + public Builder logger(String logger) { + this.logger = logger; + return this; + } + + public Builder data(String data) { + this.data = data; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public LoggingMessageNotification build() { + return new LoggingMessageNotification(level, logger, data, meta); + } + + } + } + + public enum LoggingLevel { + + // @formatter:off + @JsonProperty("debug") DEBUG(0), + @JsonProperty("info") INFO(1), + @JsonProperty("notice") NOTICE(2), + @JsonProperty("warning") WARNING(3), + @JsonProperty("error") ERROR(4), + @JsonProperty("critical") CRITICAL(5), + @JsonProperty("alert") ALERT(6), + @JsonProperty("emergency") EMERGENCY(7); + // @formatter:on + + private final int level; + + LoggingLevel(int level) { + this.level = level; + } + + public int level() { + return level; + } + + } + + /** + * A request from the client to the server, to enable or adjust logging. + * + * @param level The level of logging that the client wants to receive from the server. + * The server should send all logs at this level and higher (i.e., more severe) to the + * client as notifications/message + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record SetLevelRequest(@JsonProperty("level") LoggingLevel level) { + } + + // --------------------------- + // Autocomplete + // --------------------------- + public sealed interface CompleteReference permits PromptReference, ResourceReference { + + String type(); + + String identifier(); + + } + + /** + * Identifies a prompt for completion requests. + * + * @param type The reference type identifier (typically "ref/prompt") + * @param name The name of the prompt + * @param title An optional title for the prompt + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PromptReference( // @formatter:off + @JsonProperty("type") String type, + @JsonProperty("name") String name, + @JsonProperty("title") String title ) implements McpSchema.CompleteReference, Identifier { // @formatter:on + + public static final String TYPE = "ref/prompt"; + + public PromptReference(String type, String name) { + this(type, name, null); + } + + public PromptReference(String name) { + this(TYPE, name, null); + } + + @Override + public String identifier() { + return name(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null || getClass() != obj.getClass()) + return false; + PromptReference that = (PromptReference) obj; + return java.util.Objects.equals(identifier(), that.identifier()) + && java.util.Objects.equals(type(), that.type()); + } + + @Override + public int hashCode() { + return java.util.Objects.hash(identifier(), type()); + } + } + + /** + * A reference to a resource or resource template definition for completion requests. + * + * @param type The reference type identifier (typically "ref/resource") + * @param uri The URI or URI template of the resource + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ResourceReference( // @formatter:off + @JsonProperty("type") String type, + @JsonProperty("uri") String uri) implements McpSchema.CompleteReference { // @formatter:on + + public static final String TYPE = "ref/resource"; + + public ResourceReference(String uri) { + this(TYPE, uri); + } + + @Override + public String identifier() { + return uri(); + } + } + + /** + * A request from the client to the server, to ask for completion options. + * + * @param ref A reference to a prompt or resource template definition + * @param argument The argument's information for completion requests + * @param meta See specification for notes on _meta usage + * @param context Additional, optional context for completions + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CompleteRequest( // @formatter:off + @JsonProperty("ref") McpSchema.CompleteReference ref, + @JsonProperty("argument") CompleteArgument argument, + @JsonProperty("_meta") Map meta, + @JsonProperty("context") CompleteContext context) implements Request { // @formatter:on + + public CompleteRequest(McpSchema.CompleteReference ref, CompleteArgument argument, Map meta) { + this(ref, argument, meta, null); + } + + public CompleteRequest(McpSchema.CompleteReference ref, CompleteArgument argument, CompleteContext context) { + this(ref, argument, null, context); + } + + public CompleteRequest(McpSchema.CompleteReference ref, CompleteArgument argument) { + this(ref, argument, null, null); + } + + /** + * The argument's information for completion requests. + * + * @param name The name of the argument + * @param value The value of the argument to use for completion matching + */ + public record CompleteArgument(@JsonProperty("name") String name, @JsonProperty("value") String value) { + } + + /** + * Additional, optional context for completions. + * + * @param arguments Previously-resolved variables in a URI template or prompt + */ + public record CompleteContext(@JsonProperty("arguments") Map arguments) { + } + } + + /** + * The server's response to a completion/complete request. + * + * @param completion The completion information containing values and metadata + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CompleteResult(// @formatter:off + @JsonProperty("completion") CompleteCompletion completion, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + // backwards compatibility constructor + public CompleteResult(CompleteCompletion completion) { + this(completion, null); + } + + /** + * The server's response to a completion/complete request + * + * @param values An array of completion values. Must not exceed 100 items + * @param total The total number of completion options available. This can exceed + * the number of values actually sent in the response + * @param hasMore Indicates whether there are additional completion options beyond + * those provided in the current response, even if the exact total is unknown + */ + @JsonInclude(JsonInclude.Include.ALWAYS) + public record CompleteCompletion( // @formatter:off + @JsonProperty("values") List values, + @JsonProperty("total") Integer total, + @JsonProperty("hasMore") Boolean hasMore) { // @formatter:on + } + } + + // --------------------------- + // Content Types + // --------------------------- + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type") + @JsonSubTypes({ @JsonSubTypes.Type(value = TextContent.class, name = "text"), + @JsonSubTypes.Type(value = ImageContent.class, name = "image"), + @JsonSubTypes.Type(value = AudioContent.class, name = "audio"), + @JsonSubTypes.Type(value = EmbeddedResource.class, name = "resource"), + @JsonSubTypes.Type(value = ResourceLink.class, name = "resource_link") }) + public sealed interface Content extends Meta + permits TextContent, ImageContent, AudioContent, EmbeddedResource, ResourceLink { + + default String type() { + if (this instanceof TextContent) { + return "text"; + } + else if (this instanceof ImageContent) { + return "image"; + } + else if (this instanceof AudioContent) { + return "audio"; + } + else if (this instanceof EmbeddedResource) { + return "resource"; + } + else if (this instanceof ResourceLink) { + return "resource_link"; + } + throw new IllegalArgumentException("Unknown content type: " + this); + } + + } + + /** + * Text provided to or from an LLM. + * + * @param annotations Optional annotations for the client + * @param text The text content of the message + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record TextContent( // @formatter:off + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("text") String text, + @JsonProperty("_meta") Map meta) implements Annotated, Content { // @formatter:on + + public TextContent(Annotations annotations, String text) { + this(annotations, text, null); + } + + public TextContent(String content) { + this(null, content, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link TextContent#TextContent(Annotations, String)} instead. + */ + @Deprecated + public TextContent(List audience, Double priority, String content) { + this(audience != null || priority != null ? new Annotations(audience, priority) : null, content, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link TextContent#annotations()} instead. + */ + @Deprecated + public List audience() { + return annotations == null ? null : annotations.audience(); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link TextContent#annotations()} instead. + */ + @Deprecated + public Double priority() { + return annotations == null ? null : annotations.priority(); + } + } + + /** + * An image provided to or from an LLM. + * + * @param annotations Optional annotations for the client + * @param data The base64-encoded image data + * @param mimeType The MIME type of the image. Different providers may support + * different image types + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ImageContent( // @formatter:off + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("data") String data, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("_meta") Map meta) implements Annotated, Content { // @formatter:on + + public ImageContent(Annotations annotations, String data, String mimeType) { + this(annotations, data, mimeType, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link ImageContent#ImageContent(Annotations, String, String)} instead. + */ + @Deprecated + public ImageContent(List audience, Double priority, String data, String mimeType) { + this(audience != null || priority != null ? new Annotations(audience, priority) : null, data, mimeType, + null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link ImageContent#annotations()} instead. + */ + @Deprecated + public List audience() { + return annotations == null ? null : annotations.audience(); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link ImageContent#annotations()} instead. + */ + @Deprecated + public Double priority() { + return annotations == null ? null : annotations.priority(); + } + } + + /** + * Audio provided to or from an LLM. + * + * @param annotations Optional annotations for the client + * @param data The base64-encoded audio data + * @param mimeType The MIME type of the audio. Different providers may support + * different audio types + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record AudioContent( // @formatter:off + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("data") String data, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("_meta") Map meta) implements Annotated, Content { // @formatter:on + + // backwards compatibility constructor + public AudioContent(Annotations annotations, String data, String mimeType) { + this(annotations, data, mimeType, null); + } + } + + /** + * The contents of a resource, embedded into a prompt or tool call result. + * + * It is up to the client how best to render embedded resources for the benefit of the + * LLM and/or the user. + * + * @param annotations Optional annotations for the client + * @param resource The resource contents that are embedded + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record EmbeddedResource( // @formatter:off + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("resource") ResourceContents resource, + @JsonProperty("_meta") Map meta) implements Annotated, Content { // @formatter:on + + // backwards compatibility constructor + public EmbeddedResource(Annotations annotations, ResourceContents resource) { + this(annotations, resource, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link EmbeddedResource#EmbeddedResource(Annotations, ResourceContents)} + * instead. + */ + @Deprecated + public EmbeddedResource(List audience, Double priority, ResourceContents resource) { + this(audience != null || priority != null ? new Annotations(audience, priority) : null, resource, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link EmbeddedResource#annotations()} instead. + */ + @Deprecated + public List audience() { + return annotations == null ? null : annotations.audience(); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link EmbeddedResource#annotations()} instead. + */ + @Deprecated + public Double priority() { + return annotations == null ? null : annotations.priority(); + } + } + + /** + * A known resource that the server is capable of reading. + * + * @param uri the URI of the resource. + * @param name A human-readable name for this resource. This can be used by clients to + * populate UI elements. + * @param title A human-readable title for this resource. + * @param description A description of what this resource represents. This can be used + * by clients to improve the LLM's understanding of available resources. It can be + * thought of like a "hint" to the model. + * @param mimeType The MIME type of this resource, if known. + * @param size The size of the raw resource content, in bytes (i.e., before base64 + * encoding or any tokenization), if known. This can be used by Hosts to display file + * sizes and estimate context window usage. + * @param annotations Optional annotations for the client. The client can use + * annotations to inform how objects are used or displayed. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ResourceLink( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("uri") String uri, + @JsonProperty("description") String description, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("size") Long size, + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("_meta") Map meta) implements Content, ResourceContent { // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String name; + + private String title; + + private String uri; + + private String description; + + private String mimeType; + + private Annotations annotations; + + private Long size; + + private Map meta; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public Builder uri(String uri) { + this.uri = uri; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder mimeType(String mimeType) { + this.mimeType = mimeType; + return this; + } + + public Builder annotations(Annotations annotations) { + this.annotations = annotations; + return this; + } + + public Builder size(Long size) { + this.size = size; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public ResourceLink build() { + Assert.hasText(uri, "uri must not be empty"); + Assert.hasText(name, "name must not be empty"); + + return new ResourceLink(name, title, uri, description, mimeType, size, annotations, meta); + } + + } + } + + // --------------------------- + // Roots + // --------------------------- + /** + * Represents a root directory or file that the server can operate on. + * + * @param uri The URI identifying the root. This *must* start with file:// for now. + * This restriction may be relaxed in future versions of the protocol to allow other + * URI schemes. + * @param name An optional name for the root. This can be used to provide a + * human-readable identifier for the root, which may be useful for display purposes or + * for referencing the root in other parts of the application. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Root( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("name") String name, + @JsonProperty("_meta") Map meta) { // @formatter:on + + public Root(String uri, String name) { + this(uri, name, null); + } + } + + /** + * The client's response to a roots/list request from the server. This result contains + * an array of Root objects, each representing a root directory or file that the + * server can operate on. + * + * @param roots An array of Root objects, each representing a root directory or file + * that the server can operate on. + * @param nextCursor An optional cursor for pagination. If present, indicates there + * are more roots available. The client can use this cursor to request the next page + * of results by sending a roots/list request with the cursor parameter set to this + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ListRootsResult( // @formatter:off + @JsonProperty("roots") List roots, + @JsonProperty("nextCursor") String nextCursor, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public ListRootsResult(List roots) { + this(roots, null); + } + + public ListRootsResult(List roots, String nextCursor) { + this(roots, nextCursor, null); + } + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java new file mode 100644 index 000000000..241f7d8b5 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -0,0 +1,444 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpInitRequestHandler; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.core.publisher.Sinks; + +/** + * Represents a Model Context Protocol (MCP) session on the server side. It manages + * bidirectional JSON-RPC communication with the client. + */ +public class McpServerSession implements McpLoggableSession { + + private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class); + + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + private final String id; + + /** Duration to wait for request responses before timing out */ + private final Duration requestTimeout; + + private final AtomicLong requestCounter = new AtomicLong(0); + + private final McpInitRequestHandler initRequestHandler; + + private final Map> requestHandlers; + + private final Map notificationHandlers; + + private final McpServerTransport transport; + + private final Sinks.One exchangeSink = Sinks.one(); + + private final AtomicReference clientCapabilities = new AtomicReference<>(); + + private final AtomicReference clientInfo = new AtomicReference<>(); + + private static final int STATE_UNINITIALIZED = 0; + + private static final int STATE_INITIALIZING = 1; + + private static final int STATE_INITIALIZED = 2; + + private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + /** + * Creates a new server session with the given parameters and the transport to use. + * @param id session id + * @param transport the transport to use + * @param initHandler called when a + * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the + * server + * @param requestHandlers map of request handlers to use + * @param notificationHandlers map of notification handlers to use + */ + public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, + McpInitRequestHandler initHandler, Map> requestHandlers, + Map notificationHandlers) { + this.id = id; + this.requestTimeout = requestTimeout; + this.transport = transport; + this.initRequestHandler = initHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + /** + * Creates a new server session with the given parameters and the transport to use. + * @param id session id + * @param transport the transport to use + * @param initHandler called when a + * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the + * server + * @param initNotificationHandler called when a + * {@link io.modelcontextprotocol.spec.McpSchema#METHOD_NOTIFICATION_INITIALIZED} is + * received. + * @param requestHandlers map of request handlers to use + * @param notificationHandlers map of notification handlers to use + * @deprecated Use + * {@link #McpServerSession(String, Duration, McpServerTransport, McpInitRequestHandler, Map, Map)} + */ + @Deprecated + public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, + McpInitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, + Map> requestHandlers, + Map notificationHandlers) { + this.id = id; + this.requestTimeout = requestTimeout; + this.transport = transport; + this.initRequestHandler = initHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + /** + * Retrieve the session id. + * @return session id + */ + public String getId() { + return this.id; + } + + /** + * Called upon successful initialization sequence between the client and the server + * with the client capabilities and information. + * + * Initialization + * Spec + * @param clientCapabilities the capabilities the connected client provides + * @param clientInfo the information about the connected client + */ + public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + this.clientCapabilities.lazySet(clientCapabilities); + this.clientInfo.lazySet(clientInfo); + } + + private String generateRequestId() { + return this.id + "-" + this.requestCounter.getAndIncrement(); + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { + String requestId = this.generateRequestId(); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, + requestId, requestParams); + this.transport.sendMessage(jsonrpcRequest).subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + sink.error(error); + }); + }).timeout(requestTimeout).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + method, params); + return this.transport.sendMessage(jsonrpcNotification); + } + + /** + * Called by the {@link McpServerTransportProvider} once the session is determined. + * The purpose of this method is to dispatch the message to an appropriate handler as + * specified by the MCP server implementation + * ({@link io.modelcontextprotocol.server.McpAsyncServer} or + * {@link io.modelcontextprotocol.server.McpSyncServer}) via + * {@link McpServerSession.Factory} that the server creates. + * @param message the incoming JSON-RPC message + * @return a Mono that completes when the message is processed + */ + public Mono handle(McpSchema.JSONRPCMessage message) { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + + // TODO handle errors for communication to without initialization happening + // first + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received response: {}", response); + if (response.id() != null) { + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } + } + else { + logger.error("Discarded MCP request response without session id. " + + "This is an indication of a bug in the request sender code that can lead to memory " + + "leaks as pending requests will never be completed."); + } + return Mono.empty(); + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + return handleIncomingRequest(request, transportContext).onErrorResume(error -> { + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + jsonRpcError); + // TODO: Should the error go to SSE or back as POST return? + return this.transport.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(this.transport::sendMessage); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + // TODO handle errors for communication to without initialization + // happening first + logger.debug("Received notification: {}", notification); + // TODO: in case of error, should the POST request be signalled? + return handleIncomingNotification(notification, transportContext) + .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); + } + else { + logger.warn("Received unknown message type: {}", message); + return Mono.empty(); + } + }); + } + + /** + * Handles an incoming JSON-RPC request by routing it to the appropriate handler. + * @param request The incoming JSON-RPC request + * @param transportContext + * @return A Mono containing the JSON-RPC response + */ + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request, + McpTransportContext transportContext) { + return Mono.defer(() -> { + Mono resultMono; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + // TODO handle situation where already initialized! + McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), + new TypeRef() { + }); + + this.state.lazySet(STATE_INITIALIZING); + this.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); + resultMono = this.initRequestHandler.handle(initializeRequest); + } + else { + // TODO handle errors for communication to this session without + // initialization happening first + var handler = this.requestHandlers.get(request.method()); + if (handler == null) { + MethodNotFoundError error = getMethodNotFoundError(request.method()); + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + + resultMono = this.exchangeSink.asMono() + .flatMap(exchange -> handler.handle(copyExchange(exchange, transportContext), request.params())); + } + return resultMono + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(error -> { + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + // TODO: add error message through the data field + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); + return Mono.just( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, jsonRpcError)); + }); + }); + } + + /** + * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. + * @param notification The incoming JSON-RPC notification + * @param transportContext + * @return A Mono that completes when the notification is processed + */ + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification, + McpTransportContext transportContext) { + return Mono.defer(() -> { + if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { + this.state.lazySet(STATE_INITIALIZED); + // FIXME: The session ID passed here is not the same as the one in the + // legacy SSE transport. + exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, clientCapabilities.get(), + clientInfo.get(), transportContext)); + } + + var handler = notificationHandlers.get(notification.method()); + if (handler == null) { + logger.warn("No handler registered for notification method: {}", notification); + return Mono.empty(); + } + return this.exchangeSink.asMono() + .flatMap(exchange -> handler.handle(copyExchange(exchange, transportContext), notification.params())); + }); + } + + /** + * This legacy implementation assumes an exchange is established upon the + * initialization phase see: exchangeSink.tryEmitValue(...), which creates a cached + * immutable exchange. Here, we create a new exchange and copy over everything from + * that cached exchange, and use it for a single HTTP request, with the transport + * context passed in. + */ + private McpAsyncServerExchange copyExchange(McpAsyncServerExchange exchange, McpTransportContext transportContext) { + return new McpAsyncServerExchange(exchange.sessionId(), this, exchange.getClientCapabilities(), + exchange.getClientInfo(), transportContext); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + private MethodNotFoundError getMethodNotFoundError(String method) { + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + + @Override + public Mono closeGracefully() { + // TODO: clear pendingResponses and emit errors? + return this.transport.closeGracefully(); + } + + @Override + public void close() { + // TODO: clear pendingResponses and emit errors? + this.transport.close(); + } + + /** + * Request handler for the initialization request. + * + * @deprecated Use {@link McpInitRequestHandler} + */ + @Deprecated + public interface InitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Notification handler for the initialization notification from the client. + */ + public interface InitNotificationHandler { + + /** + * Specifies an action to take upon successful initialization. + * @return a Mono that will complete when the initialization is acted upon. + */ + Mono handle(); + + } + + /** + * A handler for client-initiated notifications. + * + * @deprecated Use {@link McpNotificationHandler} + */ + @Deprecated + public interface NotificationHandler { + + /** + * Handles a notification from the client. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the notification. + * @return a Mono that completes once the notification is handled. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + + } + + /** + * A handler for client-initiated requests. + * + * @param the type of the response that is expected as a result of handling the + * request. + * @deprecated Use {@link McpRequestHandler} + */ + @Deprecated + public interface RequestHandler { + + /** + * Handles a request from the client. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return a Mono that will emit the response to the request. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + + } + + /** + * Factory for creating server sessions which delegate to a provided 1:1 transport + * with a connected client. + */ + @FunctionalInterface + public interface Factory { + + /** + * Creates a new 1:1 representation of the client-server interaction. + * @param sessionTransport the transport to use for communication with the client. + * @return a new server session. + */ + McpServerSession create(McpServerTransport sessionTransport); + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java new file mode 100644 index 000000000..39c1644e0 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java @@ -0,0 +1,15 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +/** + * Marker interface for the server-side MCP transport. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public interface McpServerTransport extends McpTransport { + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java new file mode 100644 index 000000000..02028ccdf --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -0,0 +1,23 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +/** + * Classic implementation of {@link McpServerTransportProviderBase} for a single outgoing + * stream in bidirectional communication (STDIO and the legacy HTTP SSE). + * + * @author Dariusz Jędrzejczyk + */ +public interface McpServerTransportProvider extends McpServerTransportProviderBase { + + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpServerSession.Factory sessionFactory); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java new file mode 100644 index 000000000..acb1ecac6 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java @@ -0,0 +1,71 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.util.List; +import java.util.Map; + +import reactor.core.publisher.Mono; + +/** + * The core building block providing the server-side MCP transport. Implement this + * interface to bridge between a particular server-side technology and the MCP server + * transport layer. + * + *

    + * The lifecycle of the provider dictates that it be created first, upon application + * startup, and then passed into either + * {@link io.modelcontextprotocol.server.McpServer#sync(McpServerTransportProvider)} or + * {@link io.modelcontextprotocol.server.McpServer#async(McpServerTransportProvider)}. As + * a result of the MCP server creation, the provider will be notified of a + * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication + * between a newly connected client and the server. The provider's responsibility is to + * create instances of {@link McpServerTransport} that the session will utilise during the + * session lifetime. + * + *

    + * Finally, the {@link McpServerTransport}s can be closed in bulk when {@link #close()} or + * {@link #closeGracefully()} are called as part of the normal application shutdown event. + * Individual {@link McpServerTransport}s can also be closed on a per-session basis, where + * the {@link McpServerSession#close()} or {@link McpServerSession#closeGracefully()} + * closes the provided transport. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpServerTransportProviderBase { + + /** + * Sends a notification to all connected clients. + * @param method the name of the notification method to be called on the clients + * @param params parameters to be sent with the notification + * @return a Mono that completes when the notification has been broadcast + * @see McpSession#sendNotification(String, Map) + */ + Mono notifyClients(String method, Object params); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + + /** + * Returns the protocol version supported by this transport provider. + * @return the protocol version as a string + */ + default List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java similarity index 70% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 92b460755..767ed673e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -4,13 +4,11 @@ package io.modelcontextprotocol.spec; -import java.util.Map; - -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import reactor.core.publisher.Mono; /** - * Represents a Model Control Protocol (MCP) session that handles communication between + * Represents a Model Context Protocol (MCP) session that handles communication between * clients and the server. This interface provides methods for sending requests and * notifications, as well as managing the session lifecycle. * @@ -26,26 +24,27 @@ public interface McpSession { /** - * Sends a request to the model server and expects a response of type T. + * Sends a request to the model counterparty and expects a response of type T. * *

    * This method handles the request-response pattern where a response is expected from - * the server. The response type is determined by the provided TypeReference. + * the client or server. The response type is determined by the provided + * TypeReference. *

    * @param the type of the expected response - * @param method the name of the method to be called on the server + * @param method the name of the method to be called on the counterparty * @param requestParams the parameters to be sent with the request * @param typeRef the TypeReference describing the expected response type * @return a Mono that will emit the response when received */ - Mono sendRequest(String method, Object requestParams, TypeReference typeRef); + Mono sendRequest(String method, Object requestParams, TypeRef typeRef); /** - * Sends a notification to the model server without parameters. + * Sends a notification to the model client or server without parameters. * *

    * This method implements the notification pattern where no response is expected from - * the server. It's useful for fire-and-forget scenarios. + * the counterparty. It's useful for fire-and-forget scenarios. *

    * @param method the name of the notification method to be called on the server * @return a Mono that completes when the notification has been sent @@ -55,17 +54,17 @@ default Mono sendNotification(String method) { } /** - * Sends a notification to the model server with parameters. + * Sends a notification to the model client or server with parameters. * *

    * Similar to {@link #sendNotification(String)} but allows sending additional * parameters with the notification. *

    - * @param method the name of the notification method to be called on the server - * @param params a map of parameters to be sent with the notification + * @param method the name of the notification method to be sent to the counterparty + * @param params parameters to be sent with the notification * @return a Mono that completes when the notification has been sent */ - Mono sendNotification(String method, Map params); + Mono sendNotification(String method, Object params); /** * Closes the session and releases any associated resources asynchronously. diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java new file mode 100644 index 000000000..d1c2e5206 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.util.List; + +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import reactor.core.publisher.Mono; + +public interface McpStatelessServerTransport { + + void setMcpHandler(McpStatelessServerHandler mcpHandler); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + + default List protocolVersions() { + return List.of(ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java new file mode 100644 index 000000000..95f8959f5 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -0,0 +1,420 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.json.TypeRef; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; + +/** + * Representation of a Streamable HTTP server session that keeps track of mapping + * server-initiated requests to the client and mapping arriving responses. It also allows + * handling incoming notifications. For requests, it provides the default SSE streaming + * capability without the insight into the transport-specific details of HTTP handling. + * + * @author Dariusz Jędrzejczyk + * @author Yanming Zhou + */ +public class McpStreamableServerSession implements McpLoggableSession { + + private static final Logger logger = LoggerFactory.getLogger(McpStreamableServerSession.class); + + private final ConcurrentHashMap requestIdToStream = new ConcurrentHashMap<>(); + + private final String id; + + private final Duration requestTimeout; + + private final AtomicLong requestCounter = new AtomicLong(0); + + private final Map> requestHandlers; + + private final Map notificationHandlers; + + private final AtomicReference clientCapabilities = new AtomicReference<>(); + + private final AtomicReference clientInfo = new AtomicReference<>(); + + private final AtomicReference listeningStreamRef; + + private final MissingMcpTransportSession missingMcpTransportSession; + + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + /** + * Create an instance of the streamable session. + * @param id session ID + * @param clientCapabilities client capabilities + * @param clientInfo client info + * @param requestTimeout timeout to use for requests + * @param requestHandlers the map of MCP request handlers keyed by method name + * @param notificationHandlers the map of MCP notification handlers keyed by method + * name + */ + public McpStreamableServerSession(String id, McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo, Duration requestTimeout, + Map> requestHandlers, + Map notificationHandlers) { + this.id = id; + this.missingMcpTransportSession = new MissingMcpTransportSession(id); + this.listeningStreamRef = new AtomicReference<>(this.missingMcpTransportSession); + this.clientCapabilities.lazySet(clientCapabilities); + this.clientInfo.lazySet(clientInfo); + this.requestTimeout = requestTimeout; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + + /** + * Return the Session ID. + * @return session ID + */ + public String getId() { + return this.id; + } + + private String generateRequestId() { + return this.id + "-" + this.requestCounter.getAndIncrement(); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { + return Mono.defer(() -> { + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return listeningStream.sendRequest(method, requestParams, typeRef); + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.defer(() -> { + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return listeningStream.sendNotification(method, params); + }); + } + + public Mono delete() { + return this.closeGracefully().then(Mono.fromRunnable(() -> { + // TODO: review in the context of history storage + // delete history, etc. + })); + } + + /** + * Create a listening stream (the generic HTTP GET request without Last-Event-ID + * header). + * @param transport The dedicated SSE transport stream + * @return a stream representation + */ + public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) { + McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport); + this.listeningStreamRef.set(listeningStream); + return listeningStream; + } + + // TODO: keep track of history by keeping a map from eventId to stream and then + // iterate over the events using the lastEventId + public Flux replay(Object lastEventId) { + return Flux.empty(); + } + + /** + * Provide the SSE stream of MCP messages finalized with a Response. + * @param jsonrpcRequest the MCP request triggering the stream creation + * @param transport the SSE transport stream to send messages to + * @return Mono which completes once the processing is done + */ + public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStreamableServerTransport transport) { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + + McpStreamableServerSessionStream stream = new McpStreamableServerSessionStream(transport); + McpRequestHandler requestHandler = McpStreamableServerSession.this.requestHandlers + .get(jsonrpcRequest.method()); + // TODO: delegate to stream, which upon successful response should close + // remove itself from the registry and also close the underlying transport + // (sink) + if (requestHandler == null) { + MethodNotFoundError error = getMethodNotFoundError(jsonrpcRequest.method()); + return transport + .sendMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + return requestHandler + .handle(new McpAsyncServerExchange(this.id, stream, clientCapabilities.get(), clientInfo.get(), + transportContext), jsonrpcRequest.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, + null)) + .onErrorResume(e -> { + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (e instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + e.getMessage(), McpError.aggregateExceptionMessages(e)); + + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), + null, jsonRpcError); + return Mono.just(errorResponse); + }) + .flatMap(transport::sendMessage) + .then(transport.closeGracefully()); + }); + } + + /** + * Handle the MCP notification. + * @param notification MCP notification + * @return Mono which completes upon succesful handling + */ + public Mono accept(McpSchema.JSONRPCNotification notification) { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + McpNotificationHandler notificationHandler = this.notificationHandlers.get(notification.method()); + if (notificationHandler == null) { + logger.warn("No handler registered for notification method: {}", notification); + return Mono.empty(); + } + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return notificationHandler.handle(new McpAsyncServerExchange(this.id, listeningStream, + this.clientCapabilities.get(), this.clientInfo.get(), transportContext), notification.params()); + }); + + } + + /** + * Handle the MCP response. + * @param response MCP response to the server-initiated request + * @return Mono which completes upon successful processing + */ + public Mono accept(McpSchema.JSONRPCResponse response) { + return Mono.defer(() -> { + logger.debug("Received response: {}", response); + + if (response.id() != null) { + var stream = this.requestIdToStream.get(response.id()); + if (stream == null) { + return Mono.error(McpError.builder(ErrorCodes.INTERNAL_ERROR) + .message("Unexpected response for unknown id " + response.id()) + .build()); + } + // TODO: encapsulate this inside the stream itself + var sink = stream.pendingResponses.remove(response.id()); + if (sink == null) { + return Mono.error(McpError.builder(ErrorCodes.INTERNAL_ERROR) + .message("Unexpected response for unknown id " + response.id()) + .build()); + } + else { + sink.success(response); + } + } + else { + logger.error("Discarded MCP request response without session id. " + + "This is an indication of a bug in the request sender code that can lead to memory " + + "leaks as pending requests will never be completed."); + } + return Mono.empty(); + }); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + private MethodNotFoundError getMethodNotFoundError(String method) { + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(missingMcpTransportSession); + return listeningStream.closeGracefully(); + // TODO: Also close all the open streams + }); + } + + @Override + public void close() { + McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(missingMcpTransportSession); + if (listeningStream != null) { + listeningStream.close(); + } + // TODO: Also close all open streams + } + + /** + * Request handler for the initialization request. + */ + public interface InitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Factory for new Streamable HTTP MCP sessions. + */ + public interface Factory { + + /** + * Given an initialize request, create a composite for the session initialization + * @param initializeRequest the initialization request from the client + * @return a composite allowing the session to start + */ + McpStreamableServerSessionInit startSession(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Composite holding the {@link McpStreamableServerSession} and the initialization + * result + * + * @param session the session instance + * @param initResult the result to use to respond to the client + */ + public record McpStreamableServerSessionInit(McpStreamableServerSession session, + Mono initResult) { + } + + /** + * An individual SSE stream within a Streamable HTTP context. Can be either the + * listening GET SSE stream or a request-specific POST SSE stream. + */ + public final class McpStreamableServerSessionStream implements McpLoggableSession { + + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + private final McpStreamableServerTransport transport; + + private final String transportId; + + private final Supplier uuidGenerator; + + /** + * Constructor accepting the dedicated transport representing the SSE stream. + * @param transport request-specific SSE transport stream + */ + public McpStreamableServerSessionStream(McpStreamableServerTransport transport) { + this.transport = transport; + this.transportId = UUID.randomUUID().toString(); + // This ID design allows for a constant-time extraction of the history by + // precisely identifying the SSE stream using the first component + this.uuidGenerator = () -> this.transportId + "_" + UUID.randomUUID(); + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + McpStreamableServerSession.this.setMinLoggingLevel(minLoggingLevel); + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return McpStreamableServerSession.this.isNotificationForLevelAllowed(loggingLevel); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { + String requestId = McpStreamableServerSession.this.generateRequestId(); + + McpStreamableServerSession.this.requestIdToStream.put(requestId, this); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + method, requestId, requestParams); + String messageId = this.uuidGenerator.get(); + // TODO: store message in history + this.transport.sendMessage(jsonrpcRequest, messageId).subscribe(v -> { + }, sink::error); + }).timeout(requestTimeout).doOnError(e -> { + this.pendingResponses.remove(requestId); + McpStreamableServerSession.this.requestIdToStream.remove(requestId); + }).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification( + McpSchema.JSONRPC_VERSION, method, params); + String messageId = this.uuidGenerator.get(); + // TODO: store message in history + return this.transport.sendMessage(jsonrpcNotification, messageId); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); + this.pendingResponses.clear(); + // If this was the generic stream, reset it + McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, + McpStreamableServerSession.this.missingMcpTransportSession); + McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); + return this.transport.closeGracefully(); + }); + } + + @Override + public void close() { + this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); + this.pendingResponses.clear(); + // If this was the generic stream, reset it + McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, + McpStreamableServerSession.this.missingMcpTransportSession); + McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); + this.transport.close(); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java new file mode 100644 index 000000000..f53c68900 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java @@ -0,0 +1,24 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +/** + * Streamable HTTP server transport representing an individual SSE stream. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStreamableServerTransport extends McpServerTransport { + + /** + * Send a message to the client with a message ID for use in the SSE event payload + * @param message the JSON-RPC payload + * @param messageId message id for SSE events + * @return Mono which completes when done + */ + Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java new file mode 100644 index 000000000..09fe9fb0e --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +/** + * The core building block providing the server-side MCP transport for Streamable HTTP + * servers. Implement this interface to bridge between a particular server-side technology + * and the MCP server transport layer. + * + *

    + * The lifecycle of the provider dictates that it be created first, upon application + * startup, and then passed into either + * {@link io.modelcontextprotocol.server.McpServer#sync(McpStreamableServerTransportProvider)} + * or + * {@link io.modelcontextprotocol.server.McpServer#async(McpStreamableServerTransportProvider)}. + * As a result of the MCP server creation, the provider will be notified of a + * {@link McpStreamableServerSession.Factory} which will be used to handle a 1:1 + * communication between a newly connected client and the server using a session concept. + * The provider's responsibility is to create instances of + * {@link McpStreamableServerTransport} that the session will utilise during the session + * lifetime. + * + *

    + * Finally, the {@link McpStreamableServerTransport}s can be closed in bulk when + * {@link #close()} or {@link #closeGracefully()} are called as part of the normal + * application shutdown event. Individual {@link McpStreamableServerTransport}s can also + * be closed on a per-session basis, where the {@link McpServerSession#close()} or + * {@link McpServerSession#closeGracefully()} closes the provided transport. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStreamableServerTransportProvider extends McpServerTransportProviderBase { + + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpStreamableServerSession.Factory sessionFactory); + + /** + * Sends a notification to all connected clients. + * @param method the name of the notification method to be called on the clients + * @param params parameters to be sent with the notification + * @return a Mono that completes when the notification has been broadcast + */ + Mono notifyClients(String method, Object params); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransport.java similarity index 81% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransport.java index 344a50bfe..0a732bab6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransport.java @@ -4,10 +4,10 @@ package io.modelcontextprotocol.spec; -import java.util.function.Function; +import java.util.List; -import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.json.TypeRef; import reactor.core.publisher.Mono; /** @@ -39,16 +39,6 @@ */ public interface McpTransport { - /** - * Initializes and starts the transport connection. - * - *

    - * This method should be called before any message exchange can occur. It sets up the - * necessary resources and establishes the connection to the server. - *

    - */ - Mono connect(Function, Mono> handler); - /** * Closes the transport connection and releases any associated resources. * @@ -69,7 +59,7 @@ default void close() { Mono closeGracefully(); /** - * Sends a message to the server asynchronously. + * Sends a message to the peer asynchronously. * *

    * This method handles the transmission of messages to the server in an asynchronous @@ -87,6 +77,10 @@ default void close() { * @param typeRef the type reference for the object to unmarshal * @return the unmarshalled object */ - T unmarshalFrom(Object data, TypeReference typeRef); + T unmarshalFrom(Object data, TypeRef typeRef); + + default List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java new file mode 100644 index 000000000..cfd3dae31 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java @@ -0,0 +1,38 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.spec; + +/** + * Exception thrown when there is an issue with the transport layer of the Model Context + * Protocol (MCP). + * + *

    + * This exception is used to indicate errors that occur during communication between the + * MCP client and server, such as connection failures, protocol violations, or unexpected + * responses. + * + * @author Christian Tzolov + */ +public class McpTransportException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + public McpTransportException(String message) { + super(message); + } + + public McpTransportException(String message, Throwable cause) { + super(message, cause); + } + + public McpTransportException(Throwable cause) { + super(cause); + } + + public McpTransportException(String message, Throwable cause, boolean enableSuppression, + boolean writableStackTrace) { + super(message, cause, enableSuppression, writableStackTrace); + } + +} \ No newline at end of file diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java new file mode 100644 index 000000000..68f0fc5bb --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.util.Optional; + +import org.reactivestreams.Publisher; + +/** + * An abstraction of the session as perceived from the MCP transport layer. Not to be + * confused with the {@link McpSession} type that operates at the level of the JSON-RPC + * communication protocol and matches asynchronous responses with previously issued + * requests. + * + * @param the resource representing the connection that the transport + * manages. + * @author Dariusz Jędrzejczyk + */ +public interface McpTransportSession { + + /** + * In case of stateful MCP servers, the value is present and contains the String + * identifier for the transport-level session. + * @return optional session id + */ + Optional sessionId(); + + /** + * Stateful operation that flips the un-initialized state to initialized if this is + * the first call. If the transport provides a session id for the communication, + * argument should not be null to record the current identifier. + * @param sessionId session identifier as provided by the server + * @return if successful, this method returns {@code true} and means that a + * post-initialization step can be performed + */ + boolean markInitialized(String sessionId); + + /** + * Adds a resource that this transport session can monitor and dismiss when needed. + * @param connection the managed resource + */ + void addConnection(CONNECTION connection); + + /** + * Called when the resource is terminating by itself and the transport session does + * not need to track it anymore. + * @param connection the resource to remove from the monitored collection + */ + void removeConnection(CONNECTION connection); + + /** + * Close and clear the monitored resources. Potentially asynchronous. + */ + void close(); + + /** + * Close and clear the monitored resources in a graceful manner. + * @return completes once all resources have been dismissed + */ + Publisher closeGracefully(); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionClosedException.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionClosedException.java new file mode 100644 index 000000000..60e2850b9 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionClosedException.java @@ -0,0 +1,23 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import reactor.util.annotation.Nullable; + +/** + * Exception thrown when trying to use an {@link McpTransportSession} that has been + * closed. + * + * @see ClosedMcpTransportSession + * @author Daniel Garnier-Moiroux + */ +public class McpTransportSessionClosedException extends RuntimeException { + + public McpTransportSessionClosedException(@Nullable String sessionId) { + super(sessionId != null ? "MCP session with ID %s has been closed".formatted(sessionId) + : "MCP session has been closed"); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java new file mode 100644 index 000000000..eced49ec3 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java @@ -0,0 +1,33 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +/** + * Exception that signifies that the server does not recognize the connecting client via + * the presented transport session identifier. + * + * @author Dariusz Jędrzejczyk + */ +public class McpTransportSessionNotFoundException extends RuntimeException { + + /** + * Construct an instance with a known {@link Exception cause}. + * @param sessionId transport session identifier + * @param cause the cause that was identified as a session not found error + */ + public McpTransportSessionNotFoundException(String sessionId, Exception cause) { + super("Session " + sessionId + " not found on the server", cause); + } + + /** + * Construct an instance with the session identifier but without a {@link Exception + * cause}. + * @param sessionId transport session identifier + */ + public McpTransportSessionNotFoundException(String sessionId) { + super("Session " + sessionId + " not found on the server"); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java new file mode 100644 index 000000000..322afda63 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java @@ -0,0 +1,49 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import org.reactivestreams.Publisher; +import reactor.util.function.Tuple2; + +import java.util.Optional; + +/** + * A representation of a stream at the transport layer of the MCP protocol. In particular, + * it is currently used in the Streamable HTTP implementation to potentially be able to + * resume a broken connection from where it left off by optionally keeping track of + * attached SSE event ids. + * + * @param the resource on which the stream is being served and consumed via + * this mechanism + * @author Dariusz Jędrzejczyk + */ +public interface McpTransportStream { + + /** + * The last observed event identifier. + * @return if not empty, contains the most recent event that was consumed + */ + Optional lastId(); + + /** + * An internal stream identifier used to distinguish streams while debugging. + * @return a {@code long} stream identifier value + */ + long streamId(); + + /** + * Allows keeping track of the transport stream of events (currently an SSE stream + * from Streamable HTTP specification) and enable resumability and reconnects in case + * of stream errors. + * @param eventStream a {@link Publisher} of tuples (pairs) of an optional identifier + * associated with a collection of messages + * @return a flattened {@link Publisher} of + * {@link io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage JSON-RPC messages} + * with the identifier stripped away + */ + Publisher consumeSseStream( + Publisher, Iterable>> eventStream); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java new file mode 100644 index 000000000..0bf70d5b8 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; + +/** + * A {@link McpLoggableSession} which represents a missing stream that would allow the + * server to communicate with the client. Specifically, it can be used when a Streamable + * HTTP client has not opened a listening SSE stream to accept messages for interactions + * unrelated with concurrently running client-initiated requests. + * + * @author Dariusz Jędrzejczyk + */ +public class MissingMcpTransportSession implements McpLoggableSession { + + private final String sessionId; + + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + /** + * Create an instance with the Session ID specified. + * @param sessionId session ID + */ + public MissingMcpTransportSession(String sessionId) { + this.sessionId = sessionId; + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { + return Mono.error(new IllegalStateException("Stream unavailable for session " + this.sessionId)); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.error(new IllegalStateException("Stream unavailable for session " + this.sessionId)); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public void close() { + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java new file mode 100644 index 000000000..d3d34db62 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java @@ -0,0 +1,29 @@ +package io.modelcontextprotocol.spec; + +public interface ProtocolVersions { + + /** + * MCP protocol version for 2024-11-05. + * https://modelcontextprotocol.io/specification/2024-11-05 + */ + String MCP_2024_11_05 = "2024-11-05"; + + /** + * MCP protocol version for 2025-03-26. + * https://modelcontextprotocol.io/specification/2025-03-26 + */ + String MCP_2025_03_26 = "2025-03-26"; + + /** + * MCP protocol version for 2025-06-18. + * https://modelcontextprotocol.io/specification/2025-06-18 + */ + String MCP_2025_06_18 = "2025-06-18"; + + /** + * MCP protocol version for 2025-11-25. + * https://modelcontextprotocol.io/specification/2025-11-25 + */ + String MCP_2025_11_25 = "2025-11-25"; + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Assert.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/Assert.java similarity index 84% rename from mcp/src/main/java/io/modelcontextprotocol/util/Assert.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/Assert.java index d68188c6f..1fa6b3058 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Assert.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/Assert.java @@ -76,4 +76,17 @@ public static boolean hasText(@Nullable String str) { return (str != null && !str.isBlank()); } + /** + * Assert a boolean expression, throwing an {@code IllegalArgumentException} if the + * expression evaluates to {@code false}. + * @param expression a boolean expression + * @param message the exception message to use if the assertion fails + * @throws IllegalArgumentException if {@code expression} is {@code false} + */ + public static void isTrue(boolean expression, String message) { + if (!expression) { + throw new IllegalArgumentException(message); + } + } + } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java new file mode 100644 index 000000000..c3b922edf --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java @@ -0,0 +1,175 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Default implementation of the UriTemplateUtils interface. + *

    + * This class provides methods for extracting variables from URI templates and matching + * them against actual URIs. + * + * @author Christian Tzolov + */ +public class DefaultMcpUriTemplateManager implements McpUriTemplateManager { + + /** + * Pattern to match URI variables in the format {variableName}. + */ + private static final Pattern URI_VARIABLE_PATTERN = Pattern.compile("\\{([^/]+?)\\}"); + + private final String uriTemplate; + + /** + * Constructor for DefaultMcpUriTemplateManager. + * @param uriTemplate The URI template to be used for variable extraction + */ + public DefaultMcpUriTemplateManager(String uriTemplate) { + Assert.hasText(uriTemplate, "URI template must not be null or empty"); + this.uriTemplate = uriTemplate; + } + + /** + * Extract URI variable names from a URI template. + * @param uriTemplate The URI template containing variables in the format + * {variableName} + * @return A list of variable names extracted from the template + * @throws IllegalArgumentException if duplicate variable names are found + */ + @Override + public List getVariableNames() { + List variables = new ArrayList<>(); + Matcher matcher = URI_VARIABLE_PATTERN.matcher(this.uriTemplate); + + while (matcher.find()) { + String variableName = matcher.group(1); + if (variables.contains(variableName)) { + throw new IllegalArgumentException("Duplicate URI variable name in template: " + variableName); + } + variables.add(variableName); + } + + return variables; + } + + /** + * Extract URI variable values from the actual request URI. + *

    + * This method converts the URI template into a regex pattern, then uses that pattern + * to extract variable values from the request URI. + * @param requestUri The actual URI from the request + * @return A map of variable names to their values + * @throws IllegalArgumentException if the URI template is invalid or the request URI + * doesn't match the template pattern + */ + @Override + public Map extractVariableValues(String requestUri) { + Map variableValues = new HashMap<>(); + List uriVariables = this.getVariableNames(); + + if (!Utils.hasText(requestUri) || uriVariables.isEmpty()) { + return variableValues; + } + + try { + // Create a regex pattern by replacing each {variableName} with a capturing + // group + StringBuilder patternBuilder = new StringBuilder("^"); + + // Find all variable placeholders and their positions + Matcher variableMatcher = URI_VARIABLE_PATTERN.matcher(uriTemplate); + int lastEnd = 0; + + while (variableMatcher.find()) { + // Add the text between the last variable and this one, escaped for regex + String textBefore = uriTemplate.substring(lastEnd, variableMatcher.start()); + patternBuilder.append(Pattern.quote(textBefore)); + + // Add a capturing group for the variable + patternBuilder.append("([^/]+)"); + + lastEnd = variableMatcher.end(); + } + + // Add any remaining text after the last variable + if (lastEnd < uriTemplate.length()) { + patternBuilder.append(Pattern.quote(uriTemplate.substring(lastEnd))); + } + + patternBuilder.append("$"); + + // Compile the pattern and match against the request URI + Pattern pattern = Pattern.compile(patternBuilder.toString()); + Matcher matcher = pattern.matcher(requestUri); + + if (matcher.find() && matcher.groupCount() == uriVariables.size()) { + for (int i = 0; i < uriVariables.size(); i++) { + String value = matcher.group(i + 1); + if (value == null || value.isEmpty()) { + throw new IllegalArgumentException( + "Empty value for URI variable '" + uriVariables.get(i) + "' in URI: " + requestUri); + } + variableValues.put(uriVariables.get(i), value); + } + } + } + catch (Exception e) { + throw new IllegalArgumentException("Error parsing URI template: " + uriTemplate + " for URI: " + requestUri, + e); + } + + return variableValues; + } + + /** + * Check if a URI matches the uriTemplate with variables. + * @param uri The URI to check + * @return true if the URI matches the pattern, false otherwise + */ + @Override + public boolean matches(String uri) { + // If the uriTemplate doesn't contain variables, do a direct comparison + if (!this.isUriTemplate(this.uriTemplate)) { + return uri.equals(this.uriTemplate); + } + + // Convert the URI template into a robust regex pattern that escapes special + // characters like '?'. + StringBuilder patternBuilder = new StringBuilder("^"); + Matcher variableMatcher = URI_VARIABLE_PATTERN.matcher(this.uriTemplate); + int lastEnd = 0; + + while (variableMatcher.find()) { + // Append the literal part of the template, safely quoted + String textBefore = this.uriTemplate.substring(lastEnd, variableMatcher.start()); + patternBuilder.append(Pattern.quote(textBefore)); + // Append a capturing group for the variable itself + patternBuilder.append("([^/]+?)"); + lastEnd = variableMatcher.end(); + } + + // Append any remaining literal text after the last variable + if (lastEnd < this.uriTemplate.length()) { + patternBuilder.append(Pattern.quote(this.uriTemplate.substring(lastEnd))); + } + + patternBuilder.append("$"); + + // Check if the URI matches the regex + return Pattern.compile(patternBuilder.toString()).matcher(uri).matches(); + } + + @Override + public boolean isUriTemplate(String uri) { + return URI_VARIABLE_PATTERN.matcher(uri).find(); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManagerFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManagerFactory.java new file mode 100644 index 000000000..fd1a3bd71 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManagerFactory.java @@ -0,0 +1,24 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ + +package io.modelcontextprotocol.util; + +/** + * @author Christian Tzolov + */ +public class DefaultMcpUriTemplateManagerFactory implements McpUriTemplateManagerFactory { + + /** + * Creates a new instance of {@link McpUriTemplateManager} with the specified URI + * template. + * @param uriTemplate The URI template to be used for variable extraction + * @return A new instance of {@link McpUriTemplateManager} + * @throws IllegalArgumentException if the URI template is null or empty + */ + @Override + public McpUriTemplateManager create(String uriTemplate) { + return new DefaultMcpUriTemplateManager(uriTemplate); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java new file mode 100644 index 000000000..6d53ed516 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java @@ -0,0 +1,217 @@ +/** + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.json.TypeRef; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSession; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +/** + * A utility class for scheduling regular keep-alive calls to maintain connections. It + * sends periodic keep-alive, ping, messages to connected mcp clients to prevent idle + * timeouts. + * + * The pings are sent to all active mcp sessions at regular intervals. + * + * @author Christian Tzolov + */ +public class KeepAliveScheduler { + + private static final Logger logger = LoggerFactory.getLogger(KeepAliveScheduler.class); + + private static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { + }; + + /** Initial delay before the first keepAlive call */ + private final Duration initialDelay; + + /** Interval between subsequent keepAlive calls */ + private final Duration interval; + + /** The scheduler used for executing keepAlive calls */ + private final Scheduler scheduler; + + /** The current state of the scheduler */ + private final AtomicBoolean isRunning = new AtomicBoolean(false); + + /** The current subscription for the keepAlive calls */ + private Disposable currentSubscription; + + // TODO Currently we do not support the streams (streamable http session created by + // http post/get) + + /** Supplier for reactive McpSession instances */ + private final Supplier> mcpSessions; + + /** + * Creates a KeepAliveScheduler with a custom scheduler, initial delay, interval and a + * supplier for McpSession instances. + * @param scheduler The scheduler to use for executing keepAlive calls + * @param initialDelay Initial delay before the first keepAlive call + * @param interval Interval between subsequent keepAlive calls + * @param mcpSessions Supplier for McpSession instances + */ + KeepAliveScheduler(Scheduler scheduler, Duration initialDelay, Duration interval, + Supplier> mcpSessions) { + this.scheduler = scheduler; + this.initialDelay = initialDelay; + this.interval = interval; + this.mcpSessions = mcpSessions; + } + + /** + * Creates a new Builder instance for constructing KeepAliveScheduler. + * @return A new Builder instance + */ + public static Builder builder(Supplier> mcpSessions) { + return new Builder(mcpSessions); + } + + /** + * Starts regular keepAlive calls with sessions supplier. + * @return Disposable to control the scheduled execution + */ + public Disposable start() { + if (this.isRunning.compareAndSet(false, true)) { + + this.currentSubscription = Flux.interval(this.initialDelay, this.interval, this.scheduler) + .doOnNext(tick -> { + this.mcpSessions.get() + .flatMap(session -> session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF) + .doOnError(e -> logger.warn("Failed to send keep-alive ping to session {}: {}", session, + e.getMessage())) + .onErrorComplete()) + .subscribe(); + }) + .doOnCancel(() -> this.isRunning.set(false)) + .doOnComplete(() -> this.isRunning.set(false)) + .onErrorComplete(error -> { + logger.error("KeepAlive scheduler error", error); + this.isRunning.set(false); + return true; + }) + .subscribe(); + + return this.currentSubscription; + } + else { + throw new IllegalStateException("KeepAlive scheduler is already running. Stop it first."); + } + } + + /** + * Stops the currently running keepAlive scheduler. + */ + public void stop() { + if (this.currentSubscription != null && !this.currentSubscription.isDisposed()) { + this.currentSubscription.dispose(); + } + this.isRunning.set(false); + } + + /** + * Checks if the scheduler is currently running. + * @return true if running, false otherwise + */ + public boolean isRunning() { + return this.isRunning.get(); + } + + /** + * Shuts down the scheduler and releases resources. + */ + public void shutdown() { + stop(); + if (this.scheduler instanceof Disposable) { + ((Disposable) this.scheduler).dispose(); + } + } + + /** + * Builder class for creating KeepAliveScheduler instances with fluent API. + */ + public static class Builder { + + private Scheduler scheduler = Schedulers.boundedElastic(); + + private Duration initialDelay = Duration.ofSeconds(0); + + private Duration interval = Duration.ofSeconds(30); + + private Supplier> mcpSessions; + + /** + * Creates a new Builder instance with a supplier for McpSession instances. + * @param mcpSessions The supplier for McpSession instances + */ + Builder(Supplier> mcpSessions) { + Assert.notNull(mcpSessions, "McpSessions supplier must not be null"); + this.mcpSessions = mcpSessions; + } + + /** + * Sets the scheduler to use for executing keepAlive calls. + * @param scheduler The scheduler to use: + *
      + *
    • Schedulers.single() - single-threaded scheduler
    • + *
    • Schedulers.boundedElastic() - bounded elastic scheduler for I/O operations + * (Default)
    • + *
    • Schedulers.parallel() - parallel scheduler for CPU-intensive + * operations
    • + *
    • Schedulers.immediate() - immediate scheduler for synchronous execution
    • + *
    + * @return This builder instance for method chaining + */ + public Builder scheduler(Scheduler scheduler) { + Assert.notNull(scheduler, "Scheduler must not be null"); + this.scheduler = scheduler; + return this; + } + + /** + * Sets the initial delay before the first keepAlive call. + * @param initialDelay The initial delay duration + * @return This builder instance for method chaining + */ + public Builder initialDelay(Duration initialDelay) { + Assert.notNull(initialDelay, "Initial delay must not be null"); + this.initialDelay = initialDelay; + return this; + } + + /** + * Sets the interval between subsequent keepAlive calls. + * @param interval The interval duration + * @return This builder instance for method chaining + */ + public Builder interval(Duration interval) { + Assert.notNull(interval, "Interval must not be null"); + this.interval = interval; + return this; + } + + /** + * Builds and returns a new KeepAliveScheduler instance. + * @return A new KeepAliveScheduler configured with the builder's settings + */ + public KeepAliveScheduler build() { + return new KeepAliveScheduler(scheduler, initialDelay, interval, mcpSessions); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java new file mode 100644 index 000000000..19569e49f --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java @@ -0,0 +1,52 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.List; +import java.util.Map; + +/** + * Interface for working with URI templates. + *

    + * This interface provides methods for extracting variables from URI templates and + * matching them against actual URIs. + * + * @author Christian Tzolov + */ +public interface McpUriTemplateManager { + + /** + * Extract URI variable names from this URI template. + * @return A list of variable names extracted from the template + * @throws IllegalArgumentException if duplicate variable names are found + */ + List getVariableNames(); + + /** + * Extract URI variable values from the actual request URI. + *

    + * This method converts the URI template into a regex pattern, then uses that pattern + * to extract variable values from the request URI. + * @param uri The actual URI from the request + * @return A map of variable names to their values + * @throws IllegalArgumentException if the URI template is invalid or the request URI + * doesn't match the template pattern + */ + Map extractVariableValues(String uri); + + /** + * Indicate whether the given URI matches this template. + * @param uri the URI to match to + * @return {@code true} if it matches; {@code false} otherwise + */ + boolean matches(String uri); + + /** + * Check if the given URI is a URI template. + * @return Returns true if the URI contains variables in the format {variableName} + */ + public boolean isUriTemplate(String uri); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java new file mode 100644 index 000000000..389727b45 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java @@ -0,0 +1,23 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ + +package io.modelcontextprotocol.util; + +/** + * Factory interface for creating instances of {@link McpUriTemplateManager}. + * + * @author Christian Tzolov + */ +public interface McpUriTemplateManagerFactory { + + /** + * Creates a new instance of {@link McpUriTemplateManager} with the specified URI + * template. + * @param uriTemplate The URI template to be used for variable extraction + * @return A new instance of {@link McpUriTemplateManager} + * @throws IllegalArgumentException if the URI template is null or empty + */ + McpUriTemplateManager create(String uriTemplate); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/Utils.java new file mode 100644 index 000000000..cd420100c --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -0,0 +1,110 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.net.URI; +import java.util.Collection; +import java.util.Map; + +import reactor.util.annotation.Nullable; + +/** + * Miscellaneous utility methods. + * + * @author Christian Tzolov + */ + +public final class Utils { + + /** + * Check whether the given {@code String} contains actual text. + *

    + * More specifically, this method returns {@code true} if the {@code String} is not + * {@code null}, its length is greater than 0, and it contains at least one + * non-whitespace character. + * @param str the {@code String} to check (may be {@code null}) + * @return {@code true} if the {@code String} is not {@code null}, its length is + * greater than 0, and it does not contain whitespace only + * @see Character#isWhitespace + */ + public static boolean hasText(@Nullable String str) { + return (str != null && !str.isBlank()); + } + + /** + * Return {@code true} if the supplied Collection is {@code null} or empty. Otherwise, + * return {@code false}. + * @param collection the Collection to check + * @return whether the given Collection is empty + */ + public static boolean isEmpty(@Nullable Collection collection) { + return (collection == null || collection.isEmpty()); + } + + /** + * Return {@code true} if the supplied Map is {@code null} or empty. Otherwise, return + * {@code false}. + * @param map the Map to check + * @return whether the given Map is empty + */ + public static boolean isEmpty(@Nullable Map map) { + return (map == null || map.isEmpty()); + } + + /** + * Resolves the given endpoint URL against the base URL. + *

      + *
    • If the endpoint URL is relative, it will be resolved against the base URL.
    • + *
    • If the endpoint URL is absolute, it will be validated to ensure it matches the + * base URL's scheme, authority, and path prefix.
    • + *
    • If validation fails for an absolute URL, an {@link IllegalArgumentException} is + * thrown.
    • + *
    + * @param baseUrl The base URL (must be absolute) + * @param endpointUrl The endpoint URL (can be relative or absolute) + * @return The resolved endpoint URI + * @throws IllegalArgumentException If the absolute endpoint URL does not match the + * base URL or URI is malformed + */ + public static URI resolveUri(URI baseUrl, String endpointUrl) { + if (!Utils.hasText(endpointUrl)) { + return baseUrl; + } + URI endpointUri = URI.create(endpointUrl); + if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) { + throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL."); + } + else { + return baseUrl.resolve(endpointUri); + } + } + + /** + * Checks if the given absolute endpoint URI falls under the base URI. It validates + * the scheme, authority (host and port), and ensures that the base path is a prefix + * of the endpoint path. + * @param baseUri The base URI + * @param endpointUri The endpoint URI to check + * @return true if endpointUri is within baseUri's hierarchy, false otherwise + */ + private static boolean isUnderBaseUri(URI baseUri, URI endpointUri) { + if (!baseUri.getScheme().equals(endpointUri.getScheme()) + || !baseUri.getAuthority().equals(endpointUri.getAuthority())) { + return false; + } + + URI normalizedBase = baseUri.normalize(); + URI normalizedEndpoint = endpointUri.normalize(); + + String basePath = normalizedBase.getPath(); + String endpointPath = normalizedEndpoint.getPath(); + + if (basePath.endsWith("/")) { + basePath = basePath.substring(0, basePath.length() - 1); + } + return endpointPath.startsWith(basePath); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java new file mode 100644 index 000000000..8f68f0d6e --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManager; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests for {@link McpUriTemplateManager} and its implementations. + * + * @author Christian Tzolov + */ +public class McpUriTemplateManagerTests { + + private McpUriTemplateManagerFactory uriTemplateFactory; + + @BeforeEach + void setUp() { + this.uriTemplateFactory = new DefaultMcpUriTemplateManagerFactory(); + } + + @Test + void shouldExtractVariableNamesFromTemplate() { + List variables = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}") + .getVariableNames(); + assertEquals(2, variables.size()); + assertEquals("userId", variables.get(0)); + assertEquals("postId", variables.get(1)); + } + + @Test + void shouldReturnEmptyListWhenTemplateHasNoVariables() { + List variables = this.uriTemplateFactory.create("/api/users/all").getVariableNames(); + assertEquals(0, variables.size()); + } + + @Test + void shouldThrowExceptionWhenExtractingVariablesFromNullTemplate() { + assertThrows(IllegalArgumentException.class, () -> this.uriTemplateFactory.create(null).getVariableNames()); + } + + @Test + void shouldThrowExceptionWhenExtractingVariablesFromEmptyTemplate() { + assertThrows(IllegalArgumentException.class, () -> this.uriTemplateFactory.create("").getVariableNames()); + } + + @Test + void shouldThrowExceptionWhenTemplateContainsDuplicateVariables() { + assertThrows(IllegalArgumentException.class, + () -> this.uriTemplateFactory.create("/api/users/{userId}/posts/{userId}").getVariableNames()); + } + + @Test + void shouldExtractVariableValuesFromRequestUri() { + Map values = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}") + .extractVariableValues("/api/users/123/posts/456"); + assertEquals(2, values.size()); + assertEquals("123", values.get("userId")); + assertEquals("456", values.get("postId")); + } + + @Test + void shouldReturnEmptyMapWhenTemplateHasNoVariables() { + Map values = this.uriTemplateFactory.create("/api/users/all") + .extractVariableValues("/api/users/all"); + assertEquals(0, values.size()); + } + + @Test + void shouldReturnEmptyMapWhenRequestUriIsNull() { + Map values = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}") + .extractVariableValues(null); + assertEquals(0, values.size()); + } + + @Test + void shouldMatchUriAgainstTemplatePattern() { + var uriTemplateManager = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}"); + + assertTrue(uriTemplateManager.matches("/api/users/123/posts/456")); + assertFalse(uriTemplateManager.matches("/api/users/123/comments/456")); + } + + @Test + void shouldMatchUriWithQueryParameters() { + String templateWithQuery = "file://name/search?={search}"; + var uriTemplateManager = this.uriTemplateFactory.create(templateWithQuery); + + assertTrue(uriTemplateManager.matches("file://name/search?=abcd"), + "Should correctly match a URI containing query parameters."); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java similarity index 70% rename from mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java rename to mcp-core/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java index d4e48ea7d..9854de210 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java @@ -9,37 +9,46 @@ import java.util.function.BiConsumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; /** - * A mock implementation of the {@link ClientMcpTransport} and {@link ServerMcpTransport} - * interfaces. + * A mock implementation of the {@link McpClientTransport} interfaces. */ -public class MockMcpTransport implements ClientMcpTransport, ServerMcpTransport { +public class MockMcpClientTransport implements McpClientTransport { private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); private final List sent = new ArrayList<>(); - private final BiConsumer interceptor; + private final BiConsumer interceptor; - public MockMcpTransport() { + private String protocolVersion = McpSchema.LATEST_PROTOCOL_VERSION; + + public MockMcpClientTransport() { this((t, msg) -> { }); } - public MockMcpTransport(BiConsumer interceptor) { + public MockMcpClientTransport(BiConsumer interceptor) { this.interceptor = interceptor; } + public MockMcpClientTransport withProtocolVersion(String protocolVersion) { + return this; + } + + @Override + public List protocolVersions() { + return List.of(protocolVersion); + } + public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { if (inbound.tryEmitNext(message).isFailure()) { throw new RuntimeException("Failed to process incoming message " + message); @@ -90,8 +99,8 @@ public Mono closeGracefully() { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return McpJsonMapper.getDefault().convertValue(data, typeRef); } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java new file mode 100644 index 000000000..f3d6b77a7 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java @@ -0,0 +1,74 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiConsumer; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.McpServerTransport; +import reactor.core.publisher.Mono; + +/** + * A mock implementation of the {@link McpServerTransport} interfaces. + */ +public class MockMcpServerTransport implements McpServerTransport { + + private final List sent = new ArrayList<>(); + + private final BiConsumer interceptor; + + public MockMcpServerTransport() { + this((t, msg) -> { + }); + } + + public MockMcpServerTransport(BiConsumer interceptor) { + this.interceptor = interceptor; + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + sent.add(message); + interceptor.accept(this, message); + return Mono.empty(); + } + + public McpSchema.JSONRPCRequest getLastSentMessageAsRequest() { + return (JSONRPCRequest) getLastSentMessage(); + } + + public McpSchema.JSONRPCNotification getLastSentMessageAsNotification() { + return (JSONRPCNotification) getLastSentMessage(); + } + + public McpSchema.JSONRPCMessage getLastSentMessage() { + return !sent.isEmpty() ? sent.get(sent.size() - 1) : null; + } + + public void clearSentMessages() { + sent.clear(); + } + + public List getAllSentMessages() { + return new ArrayList<>(sent); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return McpJsonMapper.getDefault().convertValue(data, typeRef); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java new file mode 100644 index 000000000..e955be89f --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -0,0 +1,50 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerSession.Factory; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import reactor.core.publisher.Mono; + +/** + * @author Christian Tzolov + */ +public class MockMcpServerTransportProvider implements McpServerTransportProvider { + + private McpServerSession session; + + private final MockMcpServerTransport transport; + + public MockMcpServerTransportProvider(MockMcpServerTransport transport) { + this.transport = transport; + } + + public MockMcpServerTransport getTransport() { + return transport; + } + + @Override + public void setSessionFactory(Factory sessionFactory) { + + session = sessionFactory.create(transport); + } + + @Override + public Mono notifyClients(String method, Object params) { + return session.sendNotification(method, params); + } + + @Override + public Mono closeGracefully() { + return session.closeGracefully(); + } + + public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { + session.handle(message).subscribe(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..18a5cb999 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -0,0 +1,231 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpTransportSessionClosedException; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Resiliency test suite for the {@link McpAsyncClient} that can be used with different + * {@link McpTransport} implementations that support Streamable HTTP. + * + * The purpose of these tests is to allow validating the transport layer resiliency + * instead of the functionality offered by the logical layer of MCP concepts such as + * tools, resources, prompts, etc. + * + * @author Dariusz Jędrzejczyk + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpAsyncClientResiliencyTests { + + private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + static void disconnect() { + long start = System.nanoTime(); + try { + // disconnect + // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", + // ToxicDirection.DOWNSTREAM, 0); + // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", + // ToxicDirection.UPSTREAM, 0); + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + static void restartMcpServer() { + container.stop(); + container.start(); + } + + abstract McpClientTransport createMcpTransport(); + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + // Do not advertise roots. Otherwise, the server will list roots during + // initialization. The client responds asynchronously, and there might be a + // rest condition in tests where we disconnect right after initialization. + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(McpSchema.ClientCapabilities.builder().build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionInvalidation() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + restartMcpServer(); + + // The first try will face the session mismatch exception and the second one + // will go through the re-initialization process. + StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + AtomicReference> tools = new AtomicReference<>(); + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + StepVerifier.create(mcpAsyncClient.listTools()) + .consumeNextWith(list -> tools.set(list.tools())) + .verifyComplete(); + + disconnect(); + + String name = tools.get().get(0).name(); + // Assuming this is the echo tool + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(name, Map.of("message", "hello")); + StepVerifier.create(mcpAsyncClient.callTool(request)).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.callTool(request)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionClose() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + // In case of Streamable HTTP this call should issue a HTTP DELETE request + // invalidating the session + StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); + // The next tries to use the closed session and fails + StepVerifier.create(mcpAsyncClient.ping()) + .expectErrorMatches(err -> err.getCause() instanceof McpTransportSessionClosedException) + .verify(); + }); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java new file mode 100644 index 000000000..5b7877971 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -0,0 +1,829 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ResourceContents; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest; +import io.modelcontextprotocol.spec.McpTransport; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +/** + * Test suite for the {@link McpAsyncClient} that can be used with different + * {@link McpTransport} implementations. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpAsyncClientTests { + + private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; + + abstract protected McpClientTransport createMcpTransport(); + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(20); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .sampling(req -> Mono.just(new CreateMessageResult(McpSchema.Role.USER, + new McpSchema.TextContent("Oh, hi!"), "modelId", CreateMessageResult.StopReason.END_TURN))) + .capabilities(ClientCapabilities.builder().roots(true).sampling().build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, + String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(operation.apply(mcpAsyncClient)).verifyComplete(); + }); + } + + void verifyCallSucceedsWithImplicitInitialization(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(operation.apply(mcpAsyncClient)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Request timeout must not be null"); + } + + @Test + void testListToolsWithoutInitialization() { + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(McpSchema.FIRST_PAGE), "listing tools"); + } + + @Test + void testListTools() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(McpSchema.FIRST_PAGE))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); + + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllTools() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools())) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); + + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllToolsReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools())) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy(() -> result.tools() + .add(Tool.builder() + .name("test") + .title("test") + .inputSchema(JSON_MAPPER, "{\"type\":\"object\"}") + .build())) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); + } + + @Test + void testPingWithoutInitialization() { + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) + .expectNextCount(1) + .verifyComplete(); + }); + } + + @Test + void testCallToolWithoutInitialization() { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }) + .verifyComplete(); + }); + } + + @Test + void testCallToolWithInvalidTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", + Map.of("message", ECHO_TEST_MESSAGE)); + + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .consumeErrorWith( + e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) + .verify(); + }); + } + + @ParameterizedTest + @ValueSource(strings = { "success", "error", "debug" }) + void testCallToolWithMessageAnnotations(String messageType) { + McpClientTransport transport = createMcpTransport(); + + withClient(transport, mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize() + .then(mcpAsyncClient.callTool(new McpSchema.CallToolRequest("annotatedMessage", + Map.of("messageType", messageType, "includeImage", true))))) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isNotEqualTo(true); + assertThat(result.content()).isNotEmpty(); + assertThat(result.content()).allSatisfy(content -> { + switch (content.type()) { + case "text": + McpSchema.TextContent textContent = assertInstanceOf(McpSchema.TextContent.class, + content); + assertThat(textContent.text()).isNotEmpty(); + assertThat(textContent.annotations()).isNotNull(); + + switch (messageType) { + case "error": + assertThat(textContent.annotations().priority()).isEqualTo(1.0); + assertThat(textContent.annotations().audience()) + .containsOnly(McpSchema.Role.USER, McpSchema.Role.ASSISTANT); + break; + case "success": + assertThat(textContent.annotations().priority()).isEqualTo(0.7); + assertThat(textContent.annotations().audience()) + .containsExactly(McpSchema.Role.USER); + break; + case "debug": + assertThat(textContent.annotations().priority()).isEqualTo(0.3); + assertThat(textContent.annotations().audience()) + .containsExactly(McpSchema.Role.ASSISTANT); + break; + default: + throw new IllegalStateException("Unexpected value: " + content.type()); + } + break; + case "image": + McpSchema.ImageContent imageContent = assertInstanceOf(McpSchema.ImageContent.class, + content); + assertThat(imageContent.data()).isNotEmpty(); + assertThat(imageContent.annotations()).isNotNull(); + assertThat(imageContent.annotations().priority()).isEqualTo(0.5); + assertThat(imageContent.annotations().audience()).containsExactly(McpSchema.Role.USER); + break; + default: + fail("Unexpected content type: " + content.type()); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testListResourcesWithoutInitialization() { + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(McpSchema.FIRST_PAGE), + "listing resources"); + } + + @Test + void testListResources() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(McpSchema.FIRST_PAGE))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllResources() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources())) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllResourcesReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources())) + .consumeNextWith(result -> { + assertThat(result.resources()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy( + () -> result.resources().add(Resource.builder().uri("test://uri").name("test").build())) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); + } + + @Test + void testMcpAsyncClientState() { + withClient(createMcpTransport(), mcpAsyncClient -> { + assertThat(mcpAsyncClient).isNotNull(); + }); + } + + @Test + void testListPromptsWithoutInitialization() { + verifyCallSucceedsWithImplicitInitialization(client -> client.listPrompts(McpSchema.FIRST_PAGE), + "listing " + "prompts"); + } + + @Test + void testListPrompts() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(McpSchema.FIRST_PAGE))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllPrompts() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts())) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllPromptsReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts())) + .consumeNextWith(result -> { + assertThat(result.prompts()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy(() -> result.prompts().add(new Prompt("test", "test", "test", null))) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); + } + + @Test + void testGetPromptWithoutInitialization() { + GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); + verifyCallSucceedsWithImplicitInitialization(client -> client.getPrompt(request), "getting " + "prompts"); + } + + @Test + void testGetPrompt() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize() + .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) + .consumeNextWith(prompt -> { + assertThat(prompt).isNotNull().satisfies(result -> { + assertThat(result.messages()).isNotEmpty(); + assertThat(result.messages()).hasSize(1); + }); + }) + .verifyComplete(); + }); + } + + @Test + void testRootsListChangedWithoutInitialization() { + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); + } + + @Test + void testRootsListChanged() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); + }); + } + + @Test + void testInitializeWithRootsListProviders() { + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + client -> { + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + }); + } + + @Test + void testAddRoot() { + withClient(createMcpTransport(), mcpAsyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + }); + } + + @Test + void testAddRootWithNullValue() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Root must not be null")) + .verify(); + }); + } + + @Test + void testRemoveRoot() { + withClient(createMcpTransport(), mcpAsyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); + + StepVerifier.create(mcpAsyncClient.removeRoot(root.uri())).verifyComplete(); + }); + } + + @Test + void testRemoveNonExistentRoot() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalStateException.class) + .hasMessage("Root with uri 'nonexistent-uri' not found")) + .verify(); + }); + } + + @Test + void testReadResource() { + AtomicInteger resourceCount = new AtomicInteger(); + withClient(createMcpTransport(), client -> { + Flux resources = client.initialize() + .then(client.listResources(null)) + .flatMapMany(r -> { + List l = r.resources(); + resourceCount.set(l.size()); + return Flux.fromIterable(l); + }) + .flatMap(r -> client.readResource(r)); + + StepVerifier.create(resources) + .recordWith(ArrayList::new) + .thenConsumeWhile(res -> true) + .consumeRecordedWith(readResourceResults -> { + assertThat(readResourceResults.size()).isEqualTo(resourceCount.get()); + for (ReadResourceResult result : readResourceResults) { + + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull().isNotEmpty(); + + // Validate each content item + for (ResourceContents content : result.contents()) { + assertThat(content).isNotNull(); + assertThat(content.uri()).isNotNull().isNotEmpty(); + assertThat(content.mimeType()).isNotNull().isNotEmpty(); + + // Validate content based on its type with more comprehensive + // checks + switch (content.mimeType()) { + case "text/plain" -> { + TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, + content); + assertThat(textContent.text()).isNotNull().isNotEmpty(); + assertThat(textContent.uri()).isNotEmpty(); + } + case "application/octet-stream" -> { + BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, + content); + assertThat(blobContent.blob()).isNotNull().isNotEmpty(); + assertThat(blobContent.uri()).isNotNull().isNotEmpty(); + // Validate base64 encoding format + assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$"); + } + default -> { + + // Still validate basic properties + if (content instanceof TextResourceContents textContent) { + assertThat(textContent.text()).isNotNull(); + } + else if (content instanceof BlobResourceContents blobContent) { + assertThat(blobContent.blob()).isNotNull(); + } + } + } + } + } + }) + .verifyComplete(); + }); + } + + @Test + void testListResourceTemplatesWithoutInitialization() { + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(McpSchema.FIRST_PAGE), + "listing resource templates"); + } + + @Test + void testListResourceTemplates() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates(McpSchema.FIRST_PAGE))) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllResourceTemplates() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllResourceTemplatesReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result.resourceTemplates()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy(() -> result.resourceTemplates() + .add(new McpSchema.ResourceTemplate("test://template", "test", "test", null, null, null))) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); + } + + // @Test + void testResourceSubscription() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + + // Test subscribe + StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .verifyComplete(); + + // Test unsubscribe + StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .verifyComplete(); + } + }).verifyComplete(); + }); + } + + @Test + void testNotificationHandlers() { + AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); + AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); + + withClient(createMcpTransport(), + builder -> builder + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), + mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()) + .expectNextMatches(Objects::nonNull) + .verifyComplete(); + }); + } + + @Test + void testInitializeWithSamplingCapability() { + ClientCapabilities capabilities = ClientCapabilities.builder().sampling().build(); + CreateMessageResult createMessageResult = CreateMessageResult.builder() + .message("test") + .model("test-model") + .build(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); + } + + @Test + void testInitializeWithElicitationCapability() { + ClientCapabilities capabilities = ClientCapabilities.builder().elicitation().build(); + ElicitResult elicitResult = ElicitResult.builder() + .message(ElicitResult.Action.ACCEPT) + .content(Map.of("foo", "bar")) + .build(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).elicitation(request -> Mono.just(elicitResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); + } + + @Test + void testInitializeWithAllCapabilities() { + var capabilities = ClientCapabilities.builder() + .experimental(Map.of("feature", Map.of("featureFlag", true))) + .roots(true) + .sampling() + .build(); + + Function> samplingHandler = request -> Mono + .just(CreateMessageResult.builder().message("test").model("test-model").build()); + + Function> elicitationHandler = request -> Mono + .just(ElicitResult.builder().message(ElicitResult.Action.ACCEPT).content(Map.of("foo", "bar")).build()); + + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(samplingHandler).elicitation(elicitationHandler), + client -> + + StepVerifier.create(client.initialize()).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.capabilities()).isNotNull(); + }).verifyComplete()); + } + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevelsWithoutInitialization() { + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); + } + + @Test + void testLoggingLevels() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize() + .thenMany(Flux.fromArray(McpSchema.LoggingLevel.values()).flatMap(mcpAsyncClient::setLoggingLevel))) + .verifyComplete(); + }); + } + + @Test + void testLoggingConsumer() { + AtomicBoolean logReceived = new AtomicBoolean(false); + + withClient(createMcpTransport(), + builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + StepVerifier.create(client.closeGracefully()).verifyComplete(); + + }); + + } + + @Test + void testLoggingWithNullNotification() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); + }); + } + + @Test + void testSampling() { + McpClientTransport transport = createMcpTransport(); + + final String message = "Hello, world!"; + final String response = "Goodbye, world!"; + final int maxTokens = 100; + + AtomicReference receivedPrompt = new AtomicReference<>(); + AtomicReference receivedMessage = new AtomicReference<>(); + AtomicInteger receivedMaxTokens = new AtomicInteger(); + + withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build()) + .sampling(request -> { + McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class, + request.messages().get(0).content()); + receivedPrompt.set(request.systemPrompt()); + receivedMessage.set(messageText.text()); + receivedMaxTokens.set(request.maxTokens()); + + return Mono + .just(new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response), + "modelId", McpSchema.CreateMessageResult.StopReason.END_TURN)); + }), client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool( + new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens)))) + .consumeNextWith(result -> { + // Verify tool response to ensure our sampling response was passed + // through + assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class); + assertThat(result.content()).allSatisfy(content -> { + if (!(content instanceof McpSchema.TextContent text)) + return; + + assertThat(text.text()).contains(response); + }); + + // Verify sampling request parameters received in our callback + assertThat(receivedPrompt.get()).isNotEmpty(); + assertThat(receivedMessage.get()).endsWith(message); // Prefixed + assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens); + }) + .verifyComplete(); + }); + } + + // --------------------------------------- + // Progress Notification Tests + // --------------------------------------- + + @Test + void testProgressConsumer() { + Sinks.Many sink = Sinks.many().unicast().onBackpressureBuffer(); + List receivedNotifications = new CopyOnWriteArrayList<>(); + + withClient(createMcpTransport(), builder -> builder.progressConsumer(notification -> { + receivedNotifications.add(notification); + sink.tryEmitNext(notification); + return Mono.empty(); + }), client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + // Call a tool that sends progress notifications + CallToolRequest request = CallToolRequest.builder() + .name("longRunningOperation") + .arguments(Map.of("duration", 1, "steps", 2)) + .progressToken("test-token") + .build(); + + StepVerifier.create(client.callTool(request)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + }).verifyComplete(); + + // Use StepVerifier to verify the progress notifications via the sink + StepVerifier.create(sink.asFlux()).expectNextCount(2).thenCancel().verify(Duration.ofSeconds(3)); + + assertThat(receivedNotifications).hasSize(2); + assertThat(receivedNotifications.get(0).progressToken()).isEqualTo("test-token"); + }); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java new file mode 100644 index 000000000..c67fa86bb --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -0,0 +1,682 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.ListResourceTemplatesResult; +import io.modelcontextprotocol.spec.McpSchema.ListResourcesResult; +import io.modelcontextprotocol.spec.McpSchema.ListToolsResult; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ResourceContents; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; + +/** + * Unit tests for MCP Client Session functionality. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpSyncClientTests { + + private static final Logger logger = LoggerFactory.getLogger(AbstractMcpSyncClientTests.class); + + private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; + + abstract protected McpClientTransport createMcpTransport(); + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpSyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpSyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.SyncSpec builder = McpClient.sync(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + assertThat(client.closeGracefully()).isTrue(); + } + } + + static final Object DUMMY_RETURN_VALUE = new Object(); + + void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { + verifyCallSucceedsWithImplicitInitialization(client -> { + operation.accept(client); + return DUMMY_RETURN_VALUE; + }, action); + } + + void verifyCallSucceedsWithImplicitInitialization(Function blockingOperation, String action) { + withClient(createMcpTransport(), mcpSyncClient -> { + StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) + // Offload the blocking call to the real scheduler + .subscribeOn(Schedulers.boundedElastic())).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Request timeout must not be null"); + } + + @Test + void testListToolsWithoutInitialization() { + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(McpSchema.FIRST_PAGE), "listing tools"); + } + + @Test + void testListTools() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListToolsResult tools = mcpSyncClient.listTools(McpSchema.FIRST_PAGE); + + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); + + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }); + }); + } + + @Test + void testListAllTools() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListToolsResult tools = mcpSyncClient.listTools(); + + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); + + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }); + }); + } + + @Test + void testCallToolsWithoutInitialization() { + verifyCallSucceedsWithImplicitInitialization( + client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), "calling tools"); + } + + @Test + void testCallTools() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); + + assertThat(toolResult).isNotNull().satisfies(result -> { + + assertThat(result.content()).hasSize(1); + + TextContent content = (TextContent) result.content().get(0); + + assertThat(content).isNotNull(); + assertThat(content.text()).isNotNull(); + assertThat(content.text()).contains("7"); + }); + }); + } + + @Test + void testPingWithoutInitialization() { + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + }); + } + + @Test + void testCallToolWithoutInitialization() { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + + CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); + + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }); + } + + @Test + void testCallToolWithInvalidTool() { + withClient(createMcpTransport(), mcpSyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); + + assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + }); + } + + @ParameterizedTest + @ValueSource(strings = { "success", "error", "debug" }) + void testCallToolWithMessageAnnotations(String messageType) { + McpClientTransport transport = createMcpTransport(); + + withClient(transport, client -> { + client.initialize(); + + McpSchema.CallToolResult result = client.callTool(new McpSchema.CallToolRequest("annotatedMessage", + Map.of("messageType", messageType, "includeImage", true))); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isNotEqualTo(true); + assertThat(result.content()).isNotEmpty(); + assertThat(result.content()).allSatisfy(content -> { + switch (content.type()) { + case "text": + McpSchema.TextContent textContent = assertInstanceOf(McpSchema.TextContent.class, content); + assertThat(textContent.text()).isNotEmpty(); + assertThat(textContent.annotations()).isNotNull(); + + switch (messageType) { + case "error": + assertThat(textContent.annotations().priority()).isEqualTo(1.0); + assertThat(textContent.annotations().audience()).containsOnly(McpSchema.Role.USER, + McpSchema.Role.ASSISTANT); + break; + case "success": + assertThat(textContent.annotations().priority()).isEqualTo(0.7); + assertThat(textContent.annotations().audience()).containsExactly(McpSchema.Role.USER); + break; + case "debug": + assertThat(textContent.annotations().priority()).isEqualTo(0.3); + assertThat(textContent.annotations().audience()) + .containsExactly(McpSchema.Role.ASSISTANT); + break; + default: + throw new IllegalStateException("Unexpected value: " + content.type()); + } + break; + case "image": + McpSchema.ImageContent imageContent = assertInstanceOf(McpSchema.ImageContent.class, content); + assertThat(imageContent.data()).isNotEmpty(); + assertThat(imageContent.annotations()).isNotNull(); + assertThat(imageContent.annotations().priority()).isEqualTo(0.5); + assertThat(imageContent.annotations().audience()).containsExactly(McpSchema.Role.USER); + break; + default: + fail("Unexpected content type: " + content.type()); + } + }); + }); + } + + @Test + void testRootsListChangedWithoutInitialization() { + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); + } + + @Test + void testRootsListChanged() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + }); + } + + @Test + void testListResourcesWithoutInitialization() { + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(McpSchema.FIRST_PAGE), + "listing resources"); + } + + @Test + void testListResources() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(McpSchema.FIRST_PAGE); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }); + } + + @Test + void testListAllResources() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }); + } + + @Test + void testClientSessionState() { + withClient(createMcpTransport(), mcpSyncClient -> { + assertThat(mcpSyncClient).isNotNull(); + }); + } + + @Test + void testInitializeWithRootsListProviders() { + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + mcpSyncClient -> { + + assertThatCode(() -> { + mcpSyncClient.initialize(); + mcpSyncClient.close(); + }).doesNotThrowAnyException(); + }); + } + + @Test + void testAddRoot() { + withClient(createMcpTransport(), mcpSyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + }); + } + + @Test + void testAddRootWithNullValue() { + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + }); + } + + @Test + void testRemoveRoot() { + withClient(createMcpTransport(), mcpSyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + assertThatCode(() -> { + mcpSyncClient.addRoot(root); + mcpSyncClient.removeRoot(root.uri()); + }).doesNotThrowAnyException(); + }); + } + + @Test + void testRemoveNonExistentRoot() { + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) + .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + }); + } + + @Test + void testReadResourceWithoutInitialization() { + AtomicReference> resources = new AtomicReference<>(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + resources.set(mcpSyncClient.listResources().resources()); + }); + + verifyCallSucceedsWithImplicitInitialization(client -> client.readResource(resources.get().get(0)), + "reading resources"); + } + + @Test + void testReadResource() { + withClient(createMcpTransport(), mcpSyncClient -> { + + int readResourceCount = 0; + + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); + + assertThat(resources).isNotNull(); + assertThat(resources.resources()).isNotNull(); + + assertThat(resources.resources()).isNotNull().isNotEmpty(); + + // Test reading each resource individually for better error isolation + for (Resource resource : resources.resources()) { + ReadResourceResult result = mcpSyncClient.readResource(resource); + + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull().isNotEmpty(); + + readResourceCount++; + + // Validate each content item + for (ResourceContents content : result.contents()) { + assertThat(content).isNotNull(); + assertThat(content.uri()).isNotNull().isNotEmpty(); + assertThat(content.mimeType()).isNotNull().isNotEmpty(); + + // Validate content based on its type with more comprehensive + // checks + switch (content.mimeType()) { + case "text/plain" -> { + TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, content); + assertThat(textContent.text()).isNotNull().isNotEmpty(); + // Verify URI consistency + assertThat(textContent.uri()).isEqualTo(resource.uri()); + } + case "application/octet-stream" -> { + BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, content); + assertThat(blobContent.blob()).isNotNull().isNotEmpty(); + // Verify URI consistency + assertThat(blobContent.uri()).isEqualTo(resource.uri()); + // Validate base64 encoding format + assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$"); + } + default -> { + // More flexible handling of additional MIME types + // Log the unexpected type for debugging but don't fail + // the test + logger.warn("Warning: Encountered unexpected MIME type: {} for resource: {}", + content.mimeType(), resource.uri()); + + // Still validate basic properties + if (content instanceof TextResourceContents textContent) { + assertThat(textContent.text()).isNotNull(); + } + else if (content instanceof BlobResourceContents blobContent) { + assertThat(blobContent.blob()).isNotNull(); + } + } + } + } + } + + // Assert that we read exactly 10 resources + assertThat(readResourceCount).isEqualTo(10); + }); + } + + @Test + void testListResourceTemplatesWithoutInitialization() { + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(McpSchema.FIRST_PAGE), + "listing resource templates"); + } + + @Test + void testListResourceTemplates() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(McpSchema.FIRST_PAGE); + + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }); + } + + @Test + void testListAllResourceTemplates() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(); + + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }); + } + + // @Test + void testResourceSubscription() { + withClient(createMcpTransport(), mcpSyncClient -> { + ListResourcesResult resources = mcpSyncClient.listResources(null); + + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + + // Test subscribe + assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + + // Test unsubscribe + assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + } + }); + } + + @Test + void testNotificationHandlers() { + AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); + AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesUpdatedNotificationReceived = new AtomicBoolean(false); + + withClient(createMcpTransport(), + builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) + .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) + .resourcesUpdateConsumer(resources -> resourcesUpdatedNotificationReceived.set(true)), + client -> { + + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevelsWithoutInitialization() { + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); + } + + @Test + void testLoggingLevels() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); + } + }); + } + + @Test + void testLoggingConsumer() { + AtomicBoolean logReceived = new AtomicBoolean(false); + withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) + .loggingConsumer(notification -> logReceived.set(true)), client -> { + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); + } + + @Test + void testLoggingWithNullNotification() { + withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) + .hasMessageContaining("Logging level must not be null")); + } + + @Test + void testSampling() { + McpClientTransport transport = createMcpTransport(); + + final String message = "Hello, world!"; + final String response = "Goodbye, world!"; + final int maxTokens = 100; + + AtomicReference receivedPrompt = new AtomicReference<>(); + AtomicReference receivedMessage = new AtomicReference<>(); + AtomicInteger receivedMaxTokens = new AtomicInteger(); + + withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build()) + .sampling(request -> { + McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class, + request.messages().get(0).content()); + receivedPrompt.set(request.systemPrompt()); + receivedMessage.set(messageText.text()); + receivedMaxTokens.set(request.maxTokens()); + + return new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response), + "modelId", McpSchema.CreateMessageResult.StopReason.END_TURN); + }), client -> { + client.initialize(); + + McpSchema.CallToolResult result = client.callTool( + new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens))); + + // Verify tool response to ensure our sampling response was passed through + assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class); + assertThat(result.content()).allSatisfy(content -> { + if (!(content instanceof McpSchema.TextContent text)) + return; + + assertThat(text.text()).contains(response); + }); + + // Verify sampling request parameters received in our callback + assertThat(receivedPrompt.get()).isNotEmpty(); + assertThat(receivedMessage.get()).endsWith(message); // Prefixed + assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens); + }); + } + + // --------------------------------------- + // Progress Notification Tests + // --------------------------------------- + + @Test + void testProgressConsumer() { + AtomicInteger progressNotificationCount = new AtomicInteger(0); + List receivedNotifications = new CopyOnWriteArrayList<>(); + CountDownLatch latch = new CountDownLatch(2); + + withClient(createMcpTransport(), builder -> builder.progressConsumer(notification -> { + System.out.println("Received progress notification: " + notification); + receivedNotifications.add(notification); + progressNotificationCount.incrementAndGet(); + latch.countDown(); + }), client -> { + client.initialize(); + + // Call a tool that sends progress notifications + CallToolRequest request = CallToolRequest.builder() + .name("longRunningOperation") + .arguments(Map.of("duration", 1, "steps", 2)) + .progressToken("test-token") + .build(); + + CallToolResult result = client.callTool(request); + + assertThat(result).isNotNull(); + + try { + // Wait for progress notifications to be processed + latch.await(3, TimeUnit.SECONDS); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + + assertThat(progressNotificationCount.get()).isEqualTo(2); + + assertThat(receivedNotifications).isNotEmpty(); + assertThat(receivedNotifications.get(0).progressToken()).isEqualTo("test-token"); + }); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..945278154 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java @@ -0,0 +1,39 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import reactor.test.StepVerifier; + +@Timeout(15) +public class HttpClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientStreamableHttpTransport.builder(host).build(); + } + + @Test + void testPingWithExactExceptionType() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError(IOException.class).verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java new file mode 100644 index 000000000..a29ca16db --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java @@ -0,0 +1,44 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +@Timeout(15) +public class HttpClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { + + private static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientStreamableHttpTransport.builder(host).build(); + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java new file mode 100644 index 000000000..ee5e5de05 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.net.URI; +import java.util.Map; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpClientTransport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +@Timeout(15) +public class HttpClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + private final McpSyncHttpClientRequestCustomizer requestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientStreamableHttpTransport.builder(host).httpRequestCustomizer(requestCustomizer).build(); + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + @Test + void customizesRequests() { + var mcpTransportContext = McpTransportContext.create(Map.of("some-key", "some-value")); + withClient(createMcpTransport(), syncSpec -> syncSpec.transportContextProvider(() -> mcpTransportContext), + mcpSyncClient -> { + mcpSyncClient.initialize(); + + verify(requestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(URI.create(host + "/mcp")), + eq("{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}"), + eq(mcpTransportContext)); + }); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java new file mode 100644 index 000000000..e2037f415 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java @@ -0,0 +1,139 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import static org.assertj.core.api.Assertions.assertThatCode; + +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import reactor.test.StepVerifier; + +@Timeout(20) +public class HttpSseMcpAsyncClientLostConnectionTests { + + private static final Logger logger = LoggerFactory.getLogger(HttpSseMcpAsyncClientLostConnectionTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + static void disconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + McpAsyncClient client(McpClientTransport transport) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + // Do not advertise roots. Otherwise, the server will list roots during + // initialization. The client responds asynchronously, and there might be a + // rest condition in tests where we disconnect right after initialization. + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(Duration.ofSeconds(14)) + .initializationTimeout(Duration.ofSeconds(2)) + .capabilities(McpSchema.ClientCapabilities.builder().build()); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + var client = client(transport); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPingWithExactExceptionType() { + withClient(HttpClientSseClientTransport.builder(host).build(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + // Veryfiy that the exception type is IOException and not TimeoutException + StepVerifier.create(mcpAsyncClient.ping()).expectError(IOException.class).verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java similarity index 53% rename from mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpSyncClientTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index 2b8af41af..91a8b6c82 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpSyncClientTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -4,51 +4,47 @@ package io.modelcontextprotocol.client; -import java.time.Duration; - -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + /** * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. * * @author Christian Tzolov */ -@Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpSyncClientTests extends AbstractMcpSyncClientTests { +@Timeout(15) +class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { - String host = "http://localhost:3003"; + private static String host = "http://localhost:3004"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { - return new HttpClientSseClientTransport(host); + protected McpClientTransport createMcpTransport() { + return HttpClientSseClientTransport.builder(host).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - protected void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } - @Override - protected Duration getTimeoutDuration() { - return Duration.ofMillis(300); - } - } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java new file mode 100644 index 000000000..d903b3b3c --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.net.URI; +import java.util.Map; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpClientTransport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3003"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + private final McpSyncHttpClientRequestCustomizer requestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientSseClientTransport.builder(host).httpRequestCustomizer(requestCustomizer).build(); + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + @Test + void customizesRequests() { + var mcpTransportContext = McpTransportContext.create(Map.of("some-key", "some-value")); + withClient(createMcpTransport(), syncSpec -> syncSpec.transportContextProvider(() -> mcpTransportContext), + mcpSyncClient -> { + mcpSyncClient.initialize(); + + verify(requestCustomizer, atLeastOnce()).customize(any(), eq("GET"), eq(URI.create(host + "/sse")), + isNull(), eq(mcpTransportContext)); + }); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java new file mode 100644 index 000000000..6f7390f19 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java @@ -0,0 +1,280 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.util.context.ContextView; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link LifecycleInitializer} postInitializationHook functionality. + * + * @author Christian Tzolov + */ +class LifecycleInitializerPostInitializationHookTests { + + private static final Duration INITIALIZATION_TIMEOUT = Duration.ofSeconds(5); + + private static final McpSchema.ClientCapabilities CLIENT_CAPABILITIES = McpSchema.ClientCapabilities.builder() + .build(); + + private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); + + private static final List PROTOCOL_VERSIONS = List.of("1.0.0", "2.0.0"); + + private static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult("2.0.0", + McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), + "Test instructions"); + + @Mock + private McpClientSession mockClientSession; + + @Mock + private Function mockSessionSupplier; + + @Mock + private Function> mockPostInitializationHook; + + private LifecycleInitializer initializer; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenReturn(Mono.empty()); + when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession); + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(MOCK_INIT_RESULT)); + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.empty()); + when(mockClientSession.closeGracefully()).thenReturn(Mono.empty()); + + initializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); + } + + @Test + void shouldInvokePostInitializationHook() { + AtomicReference capturedInit = new AtomicReference<>(); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + capturedInit.set(invocation.getArgument(0)); + return Mono.empty(); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify hook was called + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + + // Verify the hook received correct initialization data + assertThat(capturedInit.get()).isNotNull(); + assertThat(capturedInit.get().mcpSession()).isEqualTo(mockClientSession); + assertThat(capturedInit.get().initializeResult()).isEqualTo(MOCK_INIT_RESULT); + } + + @Test + void shouldInvokePostInitializationHookOnlyOnce() { + // First initialization + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + // Second call should reuse initialization and NOT call hook again + StepVerifier.create(initializer.withInitialization("test2", init -> Mono.just("result2"))) + .expectNext("result2") + .verifyComplete(); + + // Hook should only be called once + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + } + + @Test + void shouldInvokePostInitializationHookOnlyOnceWithConcurrentRequests() { + AtomicInteger hookInvocationCount = new AtomicInteger(0); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + hookInvocationCount.incrementAndGet(); + return Mono.empty(); + }); + + // Start multiple concurrent initializations + Mono init1 = initializer.withInitialization("test1", init -> Mono.just("result1")) + .subscribeOn(Schedulers.parallel()); + Mono init2 = initializer.withInitialization("test2", init -> Mono.just("result2")) + .subscribeOn(Schedulers.parallel()); + Mono init3 = initializer.withInitialization("test3", init -> Mono.just("result3")) + .subscribeOn(Schedulers.parallel()); + + // TODO: can we assume the order of results? + StepVerifier.create(Mono.zip(init1, init2, init3)).assertNext(tuple -> { + assertThat(tuple.getT1()).isEqualTo("result1"); + assertThat(tuple.getT2()).isEqualTo("result2"); + assertThat(tuple.getT3()).isEqualTo("result3"); + }).verifyComplete(); + + // Hook should only be called once despite concurrent requests + assertThat(hookInvocationCount.get()).isEqualTo(1); + } + + @Test + void shouldFailInitializationWhenPostInitializationHookFails() { + RuntimeException hookError = new RuntimeException("Post-initialization hook failed"); + when(mockPostInitializationHook.apply(any(Initialization.class))).thenReturn(Mono.error(hookError)); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectErrorMatches(ex -> ex instanceof RuntimeException && ex.getCause() == hookError) + .verify(); + + // Verify initialization was not completed + assertThat(initializer.isInitialized()).isFalse(); + assertThat(initializer.currentInitializationResult()).isNull(); + + // Verify the hook was called + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + } + + @Test + void shouldNotInvokePostInitializationHookWhenInitializationFails() { + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.error(new RuntimeException("Initialization failed"))); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + // Hook should NOT be called when initialization fails + verify(mockPostInitializationHook, never()).apply(any(Initialization.class)); + } + + @Test + void shouldNotInvokePostInitializationHookWhenNotificationFails() { + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.error(new RuntimeException("Notification failed"))); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + // Hook should NOT be called when notification fails + verify(mockPostInitializationHook, never()).apply(any(Initialization.class)); + } + + @Test + void shouldInvokePostInitializationHookAgainAfterReinitialization() { + AtomicInteger hookInvocationCount = new AtomicInteger(0); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + hookInvocationCount.incrementAndGet(); + return Mono.empty(); + }); + + // First initialization + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + assertThat(hookInvocationCount.get()).isEqualTo(1); + + // Simulate transport session exception to trigger re-initialization + initializer.handleException(new McpTransportSessionNotFoundException("Session lost")); + + // Hook should be called twice (once for each initialization) + assertThat(hookInvocationCount.get()).isEqualTo(2); + } + + @Test + void shouldAllowPostInitializationHookToPerformAsyncOperations() { + AtomicInteger operationCount = new AtomicInteger(0); + + when(mockPostInitializationHook.apply(any(Initialization.class))) + .thenReturn(Mono.fromRunnable(() -> operationCount.incrementAndGet()).then()); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify the async operation was executed + assertThat(operationCount.get()).isEqualTo(1); + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + } + + @Test + void shouldProvideCorrectInitializationDataToHook() { + AtomicReference capturedSession = new AtomicReference<>(); + AtomicReference capturedResult = new AtomicReference<>(); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + Initialization init = invocation.getArgument(0); + capturedSession.set(init.mcpSession()); + capturedResult.set(init.initializeResult()); + return Mono.empty(); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify the hook received the correct session and result + assertThat(capturedSession.get()).isEqualTo(mockClientSession); + assertThat(capturedResult.get()).isEqualTo(MOCK_INIT_RESULT); + assertThat(capturedResult.get().protocolVersion()).isEqualTo("2.0.0"); + assertThat(capturedResult.get().serverInfo().name()).isEqualTo("test-server"); + } + + @Test + void shouldInvokePostInitializationHookAfterSuccessfulInitialization() { + AtomicReference notificationSent = new AtomicReference<>(false); + AtomicReference hookCalledAfterNotification = new AtomicReference<>(false); + + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenAnswer(invocation -> { + notificationSent.set(true); + return Mono.empty(); + }); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + // Due to flatMap chaining in doInitialize, if the hook is called, + // the notification must have been sent first + hookCalledAfterNotification.set(notificationSent.get()); + return Mono.empty(); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify the hook was called and notification was already sent at that point + assertThat(hookCalledAfterNotification.get()).isTrue(); + verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()); + verify(mockPostInitializationHook).apply(any(Initialization.class)); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java new file mode 100644 index 000000000..787ee9480 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java @@ -0,0 +1,427 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; +import reactor.util.context.Context; +import reactor.util.context.ContextView; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link LifecycleInitializer}. + */ +class LifecycleInitializerTests { + + private static final Duration INITIALIZATION_TIMEOUT = Duration.ofSeconds(5); + + private static final McpSchema.ClientCapabilities CLIENT_CAPABILITIES = McpSchema.ClientCapabilities.builder() + .build(); + + private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); + + private static final List PROTOCOL_VERSIONS = List.of("1.0.0", "2.0.0"); + + private static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult("2.0.0", + McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), + "Test instructions"); + + @Mock + private McpClientSession mockClientSession; + + @Mock + private Function mockSessionSupplier; + + @Mock + private Function> mockPostInitializationHook; + + private LifecycleInitializer initializer; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenReturn(Mono.empty()); + when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession); + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(MOCK_INIT_RESULT)); + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.empty()); + when(mockClientSession.closeGracefully()).thenReturn(Mono.empty()); + + initializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); + } + + @Test + void constructorShouldValidateParameters() { + assertThatThrownBy(() -> new LifecycleInitializer(null, CLIENT_INFO, PROTOCOL_VERSIONS, INITIALIZATION_TIMEOUT, + mockSessionSupplier, mockPostInitializationHook)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Client capabilities must not be null"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, null, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Client info must not be null"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, null, + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Protocol versions must not be empty"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, List.of(), + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Protocol versions must not be empty"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, null, + mockSessionSupplier, mockPostInitializationHook)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Initialization timeout must not be null"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, null, mockPostInitializationHook)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Session supplier must not be null"); + } + + @Test + void shouldInitializeSuccessfully() { + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .assertNext(result -> { + assertThat(result).isEqualTo(MOCK_INIT_RESULT); + assertThat(initializer.isInitialized()).isTrue(); + assertThat(initializer.currentInitializationResult()).isEqualTo(MOCK_INIT_RESULT); + }) + .verifyComplete(); + + verify(mockClientSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(McpSchema.InitializeRequest.class), + any()); + verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null)); + } + + @Test + void shouldUseLatestProtocolVersionInInitializeRequest() { + AtomicReference capturedRequest = new AtomicReference<>(); + + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> { + capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1)); + return Mono.just(MOCK_INIT_RESULT); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .assertNext(result -> { + assertThat(capturedRequest.get().protocolVersion()).isEqualTo("2.0.0"); // Latest + // version + assertThat(capturedRequest.get().capabilities()).isEqualTo(CLIENT_CAPABILITIES); + assertThat(capturedRequest.get().clientInfo()).isEqualTo(CLIENT_INFO); + }) + .verifyComplete(); + } + + @Test + void shouldFailForUnsupportedProtocolVersion() { + McpSchema.InitializeResult unsupportedResult = new McpSchema.InitializeResult("999.0.0", // Unsupported + // version + McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), + "Test instructions"); + + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(unsupportedResult)); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + verify(mockClientSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()); + } + + @Test + void shouldTimeoutOnSlowInitialization() { + VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + + Duration INITIALIZE_TIMEOUT = Duration.ofSeconds(1); + Duration SLOW_RESPONSE_DELAY = Duration.ofSeconds(5); + + LifecycleInitializer shortTimeoutInitializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, + PROTOCOL_VERSIONS, INITIALIZE_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); + + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(MOCK_INIT_RESULT).delayElement(SLOW_RESPONSE_DELAY, virtualTimeScheduler)); + + StepVerifier + .withVirtualTime(() -> shortTimeoutInitializer.withInitialization("test", + init -> Mono.just(init.initializeResult())), () -> virtualTimeScheduler, Long.MAX_VALUE) + .expectSubscription() + .expectNoEvent(INITIALIZE_TIMEOUT) + .expectError(RuntimeException.class) + .verify(); + } + + @Test + void shouldReuseExistingInitialization() { + // First initialization + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + // Second call should reuse the same initialization + StepVerifier.create(initializer.withInitialization("test2", init -> Mono.just("result2"))) + .expectNext("result2") + .verifyComplete(); + + // Verify session was created only once + verify(mockSessionSupplier, times(1)).apply(any(ContextView.class)); + verify(mockClientSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + } + + @Test + void shouldHandleConcurrentInitializationRequests() { + AtomicInteger sessionCreationCount = new AtomicInteger(0); + + when(mockSessionSupplier.apply(any(ContextView.class))).thenAnswer(invocation -> { + sessionCreationCount.incrementAndGet(); + return mockClientSession; + }); + + // Start multiple concurrent initializations using subscribeOn with parallel + // scheduler + Mono init1 = initializer.withInitialization("test1", init -> Mono.just("result1")) + .subscribeOn(Schedulers.parallel()); + Mono init2 = initializer.withInitialization("test2", init -> Mono.just("result2")) + .subscribeOn(Schedulers.parallel()); + Mono init3 = initializer.withInitialization("test3", init -> Mono.just("result3")) + .subscribeOn(Schedulers.parallel()); + + StepVerifier.create(Mono.zip(init1, init2, init3)).assertNext(tuple -> { + assertThat(tuple.getT1()).isEqualTo("result1"); + assertThat(tuple.getT2()).isEqualTo("result2"); + assertThat(tuple.getT3()).isEqualTo("result3"); + }).verifyComplete(); + + // Should only create one session despite concurrent requests + assertThat(sessionCreationCount.get()).isEqualTo(1); + verify(mockClientSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + } + + @Test + void shouldHandleInitializationFailure() { + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + // fail once + .thenReturn(Mono.error(new RuntimeException("Connection failed"))) + // succeeds on the second call + .thenReturn(Mono.just(MOCK_INIT_RESULT)); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + assertThat(initializer.isInitialized()).isFalse(); + assertThat(initializer.currentInitializationResult()).isNull(); + + // The initializer can recover from previous errors + StepVerifier + .create(initializer.withInitialization("successful init", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(initializer.isInitialized()).isTrue(); + assertThat(initializer.currentInitializationResult()).isEqualTo(MOCK_INIT_RESULT); + } + + @Test + void shouldHandleTransportSessionNotFoundException() { + // successful initialization first + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(initializer.isInitialized()).isTrue(); + + // Simulate transport session not found + initializer.handleException(new McpTransportSessionNotFoundException("Session not found")); + + assertThat(initializer.isInitialized()).isTrue(); + + // Verify that the session was closed and re-initialized + verify(mockClientSession).close(); + + // Verify session was created 2 times (once for initial and once for + // re-initialization) + verify(mockSessionSupplier, times(2)).apply(any(ContextView.class)); + } + + @Test + void shouldHandleOtherExceptions() { + // Simulate a successful initialization first + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(initializer.isInitialized()).isTrue(); + + // Simulate other exception (should not trigger re-initialization) + initializer.handleException(new RuntimeException("Some other error")); + + // Should still be initialized + assertThat(initializer.isInitialized()).isTrue(); + verify(mockClientSession, never()).close(); + // Verify that the session was not re-created + verify(mockSessionSupplier, times(1)).apply(any(ContextView.class)); + } + + @Test + void shouldCloseGracefully() { + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + StepVerifier.create(initializer.closeGracefully()).verifyComplete(); + + verify(mockClientSession).closeGracefully(); + assertThat(initializer.isInitialized()).isFalse(); + } + + @Test + void shouldCloseImmediately() { + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Close immediately + initializer.close(); + + verify(mockClientSession).close(); + assertThat(initializer.isInitialized()).isFalse(); + } + + @Test + void shouldHandleCloseWithoutInitialization() { + // Close without initialization should not throw + initializer.close(); + + StepVerifier.create(initializer.closeGracefully()).verifyComplete(); + + verify(mockClientSession, never()).close(); + verify(mockClientSession, never()).closeGracefully(); + } + + @Test + void shouldSetProtocolVersionsForTesting() { + List newVersions = List.of("3.0.0", "4.0.0"); + initializer.setProtocolVersions(newVersions); + + AtomicReference capturedRequest = new AtomicReference<>(); + + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> { + capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1)); + return Mono.just(new McpSchema.InitializeResult("4.0.0", McpSchema.ServerCapabilities.builder().build(), + new McpSchema.Implementation("test-server", "1.0.0"), "Test instructions")); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .assertNext(result -> { + // Latest from new versions + assertThat(capturedRequest.get().protocolVersion()).isEqualTo("4.0.0"); + }) + .verifyComplete(); + } + + @Test + void shouldPassContextToSessionSupplier() { + String contextKey = "test.key"; + String contextValue = "test.value"; + + AtomicReference capturedContext = new AtomicReference<>(); + + when(mockSessionSupplier.apply(any(ContextView.class))).thenAnswer(invocation -> { + capturedContext.set(invocation.getArgument(0)); + return mockClientSession; + }); + + StepVerifier + .create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult())) + .contextWrite(Context.of(contextKey, contextValue))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(capturedContext.get().hasKey(contextKey)).isTrue(); + assertThat((String) capturedContext.get().get(contextKey)).isEqualTo(contextValue); + } + + @Test + void shouldProvideAccessToMcpSessionAndInitializeResult() { + StepVerifier.create(initializer.withInitialization("test", init -> { + assertThat(init.mcpSession()).isEqualTo(mockClientSession); + assertThat(init.initializeResult()).isEqualTo(MOCK_INIT_RESULT); + return Mono.just("success"); + })).expectNext("success").verifyComplete(); + } + + @Test + void shouldHandleNotificationFailure() { + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.error(new RuntimeException("Notification failed"))); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + verify(mockClientSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null)); + } + + @Test + void shouldReturnNullWhenNotInitialized() { + assertThat(initializer.isInitialized()).isFalse(); + assertThat(initializer.currentInitializationResult()).isNull(); + } + + @Test + void shouldReinitializeAfterTransportSessionException() { + // First initialization + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + // Simulate transport session exception + initializer.handleException(new McpTransportSessionNotFoundException("Session lost")); + + // Should be able to initialize again + StepVerifier.create(initializer.withInitialization("test2", init -> Mono.just("result2"))) + .expectNext("result2") + .verifyComplete(); + + // Verify two separate initializations occurred + verify(mockSessionSupplier, times(2)).apply(any(ContextView.class)); + verify(mockClientSession, times(2)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java similarity index 54% rename from mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index b1e82b748..612a65898 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -4,24 +4,26 @@ package io.modelcontextprotocol.client; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.function.Function; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.MockMcpTransport; -import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.MockMcpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import reactor.core.publisher.Mono; import static io.modelcontextprotocol.spec.McpSchema.METHOD_INITIALIZE; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -34,22 +36,22 @@ class McpAsyncClientResponseHandlerTests { .resources(true, true) // Enable both resources and resource templates .build(); - private static MockMcpTransport initializationEnabledTransport() { + private static MockMcpClientTransport initializationEnabledTransport() { return initializationEnabledTransport(SERVER_CAPABILITIES, SERVER_INFO); } - private static MockMcpTransport initializationEnabledTransport(McpSchema.ServerCapabilities mockServerCapabilities, - McpSchema.Implementation mockServerInfo) { + private static MockMcpClientTransport initializationEnabledTransport( + McpSchema.ServerCapabilities mockServerCapabilities, McpSchema.Implementation mockServerInfo) { McpSchema.InitializeResult mockInitResult = new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION, mockServerCapabilities, mockServerInfo, "Test instructions"); - return new MockMcpTransport((t, message) -> { + return new MockMcpClientTransport((t, message) -> { if (message instanceof McpSchema.JSONRPCRequest r && METHOD_INITIALIZE.equals(r.method())) { McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, r.id(), mockInitResult, null); t.simulateIncomingMessage(initResponse); } - }); + }).withProtocolVersion(McpSchema.LATEST_PROTOCOL_VERSION); } @Test @@ -59,7 +61,7 @@ void testSuccessfulInitialization() { .tools(false) .resources(true, true) // Enable both resources and resource templates .build(); - MockMcpTransport transport = initializationEnabledTransport(serverCapabilities, serverInfo); + MockMcpClientTransport transport = initializationEnabledTransport(serverCapabilities, serverInfo); McpAsyncClient asyncMcpClient = McpClient.async(transport).build(); // Verify client is not initialized initially @@ -76,8 +78,9 @@ void testSuccessfulInitialization() { // Verify initialization result assertThat(result).isNotNull(); - assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + assertThat(result.protocolVersion()).isEqualTo(transport.protocolVersions().get(0)); assertThat(result.capabilities()).isEqualTo(serverCapabilities); + assertThat(result.capabilities().logging()).isNull(); assertThat(result.serverInfo()).isEqualTo(serverInfo); assertThat(result.instructions()).isEqualTo("Test instructions"); @@ -90,8 +93,8 @@ void testSuccessfulInitialization() { } @Test - void testToolsChangeNotificationHandling() throws JsonProcessingException { - MockMcpTransport transport = initializationEnabledTransport(); + void testToolsChangeNotificationHandling() throws IOException { + MockMcpClientTransport transport = initializationEnabledTransport(); // Create a list to store received tools for verification List receivedTools = new ArrayList<>(); @@ -107,34 +110,64 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException { // Create a mock tools list that the server will return Map inputSchema = Map.of("type", "object", "properties", Map.of(), "required", List.of()); - McpSchema.Tool mockTool = new McpSchema.Tool("test-tool", "Test Tool Description", - new ObjectMapper().writeValueAsString(inputSchema)); - McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(mockTool), null); + McpSchema.Tool mockTool = McpSchema.Tool.builder() + .name("test-tool-1") + .description("Test Tool 1 Description") + .inputSchema(JSON_MAPPER, JSON_MAPPER.writeValueAsString(inputSchema)) + .build(); + + // Create page 1 response with nextPageToken + String nextPageToken = "page2Token"; + McpSchema.ListToolsResult mockToolsResult1 = new McpSchema.ListToolsResult(List.of(mockTool), nextPageToken); // Simulate server sending tools/list_changed notification McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); transport.simulateIncomingMessage(notification); - // Simulate server response to tools/list request - McpSchema.JSONRPCRequest toolsListRequest = transport.getLastSentMessageAsRequest(); - assertThat(toolsListRequest.method()).isEqualTo(McpSchema.METHOD_TOOLS_LIST); + // Simulate server response to first tools/list request + McpSchema.JSONRPCRequest toolsListRequest1 = transport.getLastSentMessageAsRequest(); + assertThat(toolsListRequest1.method()).isEqualTo(McpSchema.METHOD_TOOLS_LIST); + + McpSchema.JSONRPCResponse toolsListResponse1 = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + toolsListRequest1.id(), mockToolsResult1, null); + transport.simulateIncomingMessage(toolsListResponse1); + + // Create mock tools for page 2 + McpSchema.Tool mockTool2 = McpSchema.Tool.builder() + .name("test-tool-2") + .description("Test Tool 2 Description") + .inputSchema(JSON_MAPPER, JSON_MAPPER.writeValueAsString(inputSchema)) + .build(); + // Create page 2 response with no nextPageToken (last page) + McpSchema.ListToolsResult mockToolsResult2 = new McpSchema.ListToolsResult(List.of(mockTool2), null); + + // Simulate server response to second tools/list request with page token + McpSchema.JSONRPCRequest toolsListRequest2 = transport.getLastSentMessageAsRequest(); + assertThat(toolsListRequest2.method()).isEqualTo(McpSchema.METHOD_TOOLS_LIST); - McpSchema.JSONRPCResponse toolsListResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, - toolsListRequest.id(), mockToolsResult, null); - transport.simulateIncomingMessage(toolsListResponse); + // Verify the page token was included in the request + PaginatedRequest params = (PaginatedRequest) toolsListRequest2.params(); + assertThat(params).isNotNull(); + assertThat(params.cursor()).isEqualTo(nextPageToken); - // Verify the consumer received the expected tools - assertThat(receivedTools).hasSize(1); - assertThat(receivedTools.get(0).name()).isEqualTo("test-tool"); - assertThat(receivedTools.get(0).description()).isEqualTo("Test Tool Description"); + McpSchema.JSONRPCResponse toolsListResponse2 = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + toolsListRequest2.id(), mockToolsResult2, null); + transport.simulateIncomingMessage(toolsListResponse2); + + // Verify the consumer received all expected tools from both pages + assertThat(receivedTools).hasSize(2); + assertThat(receivedTools.get(0).name()).isEqualTo("test-tool-1"); + assertThat(receivedTools.get(0).description()).isEqualTo("Test Tool 1 Description"); + assertThat(receivedTools.get(1).name()).isEqualTo("test-tool-2"); + assertThat(receivedTools.get(1).description()).isEqualTo("Test Tool 2 Description"); asyncMcpClient.closeGracefully(); } @Test void testRootsListRequestHandling() { - MockMcpTransport transport = initializationEnabledTransport(); + MockMcpClientTransport transport = initializationEnabledTransport(); McpAsyncClient asyncMcpClient = McpClient.async(transport) .roots(new Root("file:///test/path", "test-root")) @@ -162,7 +195,7 @@ void testRootsListRequestHandling() { @Test void testResourcesChangeNotificationHandling() { - MockMcpTransport transport = initializationEnabledTransport(); + MockMcpClientTransport transport = initializationEnabledTransport(); // Create a list to store received resources for verification List receivedResources = new ArrayList<>(); @@ -208,7 +241,7 @@ void testResourcesChangeNotificationHandling() { @Test void testPromptsChangeNotificationHandling() { - MockMcpTransport transport = initializationEnabledTransport(); + MockMcpClientTransport transport = initializationEnabledTransport(); // Create a list to store received prompts for verification List receivedPrompts = new ArrayList<>(); @@ -223,8 +256,8 @@ void testPromptsChangeNotificationHandling() { assertThat(asyncMcpClient.initialize().block()).isNotNull(); // Create a mock prompts list that the server will return - McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt", "Test Prompt Description", - List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); + McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt", "Test Prompt", "Test Prompt Description", + List.of(new McpSchema.PromptArgument("arg1", "Test argument", "Test argument", true))); McpSchema.ListPromptsResult mockPromptsResult = new McpSchema.ListPromptsResult(List.of(mockPrompt), null); // Simulate server sending prompts/list_changed notification @@ -252,7 +285,7 @@ void testPromptsChangeNotificationHandling() { @Test void testSamplingCreateMessageRequestHandling() { - MockMcpTransport transport = initializationEnabledTransport(); + MockMcpClientTransport transport = initializationEnabledTransport(); // Create a test sampling handler that echoes back the input Function> samplingHandler = request -> { @@ -293,7 +326,7 @@ void testSamplingCreateMessageRequestHandling() { assertThat(response.error()).isNull(); McpSchema.CreateMessageResult result = transport.unmarshalFrom(response.result(), - new TypeReference() { + new TypeRef() { }); assertThat(result).isNotNull(); assertThat(result.role()).isEqualTo(McpSchema.Role.ASSISTANT); @@ -306,7 +339,7 @@ void testSamplingCreateMessageRequestHandling() { @Test void testSamplingCreateMessageRequestHandlingWithoutCapability() { - MockMcpTransport transport = initializationEnabledTransport(); + MockMcpClientTransport transport = initializationEnabledTransport(); // Create client without sampling capability McpAsyncClient asyncMcpClient = McpClient.async(transport) @@ -340,13 +373,187 @@ void testSamplingCreateMessageRequestHandlingWithoutCapability() { @Test void testSamplingCreateMessageRequestHandlingWithNullHandler() { - MockMcpTransport transport = new MockMcpTransport(); + MockMcpClientTransport transport = new MockMcpClientTransport(); // Create client with sampling capability but null handler assertThatThrownBy( () -> McpClient.async(transport).capabilities(ClientCapabilities.builder().sampling().build()).build()) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Sampling handler must not be null when client capabilities include sampling"); } + @Test + @SuppressWarnings("unchecked") + void testElicitationCreateRequestHandling() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create a test elicitation handler that echoes back the input + Function> elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isInstanceOf(Map.class); + assertThat(request.requestedSchema().get("type")).isEqualTo("object"); + + var properties = request.requestedSchema().get("properties"); + assertThat(properties).isNotNull(); + assertThat(((Map) properties).get("message")).isInstanceOf(Map.class); + + return Mono.just(McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(Map.of("message", request.message())) + .build()); + }; + + // Create client with elicitation capability and handler + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.error()).isNull(); + + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeRef<>() { + }); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content()).isEqualTo(Map.of("message", "Test message")); + + asyncMcpClient.closeGracefully(); + } + + @ParameterizedTest + @EnumSource(value = McpSchema.ElicitResult.Action.class, names = { "DECLINE", "CANCEL" }) + void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create a test elicitation handler to decline the request + Function> elicitationHandler = request -> Mono + .just(McpSchema.ElicitResult.builder().message(action).build()); + + // Create client with elicitation capability and handler + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.error()).isNull(); + + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeRef<>() { + }); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(action); + assertThat(result.content()).isNull(); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCreateRequestHandlingWithoutCapability() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create client without elicitation capability + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().build()) // No elicitation + // capability + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = new McpSchema.ElicitRequest("test", + Map.of("type", "object", "properties", Map.of("test", Map.of("type", "boolean", "defaultValue", true, + "description", "test-description", "title", "test-title")))); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify error response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.result()).isNull(); + assertThat(response.error()).isNotNull(); + assertThat(response.error().message()).contains("Method not found: elicitation/create"); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCreateRequestHandlingWithNullHandler() { + MockMcpClientTransport transport = new MockMcpClientTransport(); + + // Create client with elicitation capability but null handler + assertThatThrownBy(() -> McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Elicitation handler must not be null when client capabilities include elicitation"); + } + + @Test + void testPingMessageRequestHandling() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + McpAsyncClient asyncMcpClient = McpClient.async(transport).build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Simulate incoming ping request from server + McpSchema.JSONRPCRequest pingRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_PING, "ping-id", null); + transport.simulateIncomingMessage(pingRequest); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("ping-id"); + assertThat(response.error()).isNull(); + assertThat(response.result()).isInstanceOf(Map.class); + assertThat(((Map) response.result())).isEmpty(); + + asyncMcpClient.closeGracefully(); + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java new file mode 100644 index 000000000..48bf1da5b --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java @@ -0,0 +1,310 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +class McpAsyncClientTests { + + public static final McpSchema.Implementation MOCK_SERVER_INFO = new McpSchema.Implementation("test-server", + "1.0.0"); + + public static final McpSchema.ServerCapabilities MOCK_SERVER_CAPABILITIES = McpSchema.ServerCapabilities.builder() + .tools(true) + .build(); + + public static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult( + ProtocolVersions.MCP_2024_11_05, MOCK_SERVER_CAPABILITIES, MOCK_SERVER_INFO, "Test instructions"); + + private static final String CONTEXT_KEY = "context.key"; + + private McpClientTransport createMockTransportForToolValidation(boolean hasOutputSchema, boolean invalidOutput) { + + // Create tool with or without output schema + Map inputSchemaMap = Map.of("type", "object", "properties", + Map.of("expression", Map.of("type", "string")), "required", List.of("expression")); + + McpSchema.JsonSchema inputSchema = new McpSchema.JsonSchema("object", inputSchemaMap, null, null, null, null); + McpSchema.Tool.Builder toolBuilder = McpSchema.Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .inputSchema(inputSchema); + + if (hasOutputSchema) { + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + toolBuilder.outputSchema(outputSchema); + } + + McpSchema.Tool calculatorTool = toolBuilder.build(); + McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(calculatorTool), null); + + // Create call tool result - valid or invalid based on parameter + Map structuredContent = invalidOutput ? Map.of("result", "5", "operation", "add") + : Map.of("result", 5, "operation", "add"); + + McpSchema.CallToolResult mockCallToolResult = McpSchema.CallToolResult.builder() + .addTextContent("Calculation result") + .structuredContent(structuredContent) + .build(); + + return new McpClientTransport() { + Function, Mono> handler; + + @Override + public Mono connect( + Function, Mono> handler) { + this.handler = handler; + return Mono.empty(); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!(message instanceof McpSchema.JSONRPCRequest request)) { + return Mono.empty(); + } + + McpSchema.JSONRPCResponse response; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), MOCK_INIT_RESULT, + null); + } + else if (McpSchema.METHOD_TOOLS_LIST.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), mockToolsResult, + null); + } + else if (McpSchema.METHOD_TOOLS_CALL.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + mockCallToolResult, null); + } + else { + return Mono.empty(); + } + + return handler.apply(Mono.just(response)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } + }; + } + + @Test + void validateContextPassedToTransportConnect() { + McpClientTransport transport = new McpClientTransport() { + Function, Mono> handler; + + final AtomicReference contextValue = new AtomicReference<>(); + + @Override + public Mono connect( + Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler = handler; + if (ctx.hasKey(CONTEXT_KEY)) { + this.contextValue.set(ctx.get(CONTEXT_KEY)); + } + return Mono.empty(); + }); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!"hello".equals(this.contextValue.get())) { + return Mono.error(new RuntimeException("Context value not propagated via #connect method")); + } + // We're only interested in handling the init request to provide an init + // response + if (!(message instanceof McpSchema.JSONRPCRequest)) { + return Mono.empty(); + } + McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + ((McpSchema.JSONRPCRequest) message).id(), MOCK_INIT_RESULT, null); + return handler.apply(Mono.just(initResponse)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } + }; + + assertThatCode(() -> { + McpAsyncClient client = McpClient.async(transport).build(); + client.initialize().contextWrite(ctx -> ctx.put(CONTEXT_KEY, "hello")).block(); + }).doesNotThrowAnyException(); + } + + @Test + void testCallToolWithOutputSchemaValidationSuccess() { + McpClientTransport transport = createMockTransportForToolValidation(true, false); + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .expectNextMatches(response -> { + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.structuredContent()).isInstanceOf(Map.class); + assertThat((Map) response.structuredContent()).hasSize(2); + assertThat(response.content()).hasSize(1); + return true; + }) + .verifyComplete(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); + } + + @Test + void testCallToolWithNoOutputSchemaSuccess() { + McpClientTransport transport = createMockTransportForToolValidation(false, false); + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .expectNextMatches(response -> { + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.structuredContent()).isInstanceOf(Map.class); + assertThat((Map) response.structuredContent()).hasSize(2); + assertThat(response.content()).hasSize(1); + return true; + }) + .verifyComplete(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); + } + + @Test + void testCallToolWithOutputSchemaValidationFailure() { + McpClientTransport transport = createMockTransportForToolValidation(true, true); + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .expectErrorMatches(ex -> ex instanceof IllegalArgumentException + && ex.getMessage().contains("Tool call result validation failed")) + .verify(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); + } + + @Test + void testListToolsWithEmptyCursor() { + McpSchema.Tool addTool = McpSchema.Tool.builder().name("add").description("calculate add").build(); + McpSchema.Tool subtractTool = McpSchema.Tool.builder() + .name("subtract") + .description("calculate subtract") + .build(); + McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(addTool, subtractTool), ""); + + McpClientTransport transport = new McpClientTransport() { + Function, Mono> handler; + + @Override + public Mono connect( + Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler = handler; + return Mono.empty(); + }); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!(message instanceof McpSchema.JSONRPCRequest request)) { + return Mono.empty(); + } + + McpSchema.JSONRPCResponse response; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), MOCK_INIT_RESULT, + null); + } + else if (McpSchema.METHOD_TOOLS_LIST.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), mockToolsResult, + null); + } + else { + return Mono.empty(); + } + + return handler.apply(Mono.just(response)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } + }; + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + Mono mono = client.listTools(); + McpSchema.ListToolsResult toolsResult = mono.block(); + assertThat(toolsResult).isNotNull(); + + Set names = toolsResult.tools().stream().map(McpSchema.Tool::name).collect(Collectors.toSet()); + assertThat(names).containsExactlyInAnyOrder("subtract", "add"); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java similarity index 81% rename from mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java index 58e486e19..a94b9b6a7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java @@ -7,10 +7,10 @@ import java.time.Duration; import java.util.List; -import io.modelcontextprotocol.MockMcpTransport; -import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.MockMcpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -22,13 +22,13 @@ */ class McpClientProtocolVersionTests { - private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(30); + private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(300); private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); @Test void shouldUseLatestVersionByDefault() { - MockMcpTransport transport = new MockMcpTransport(); + MockMcpClientTransport transport = new MockMcpClientTransport(); McpAsyncClient client = McpClient.async(transport) .clientInfo(CLIENT_INFO) .requestTimeout(REQUEST_TIMEOUT) @@ -37,20 +37,21 @@ void shouldUseLatestVersionByDefault() { try { Mono initializeResultMono = client.initialize(); + String protocolVersion = transport.protocolVersions().get(transport.protocolVersions().size() - 1); + StepVerifier.create(initializeResultMono).then(() -> { McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); assertThat(request.params()).isInstanceOf(McpSchema.InitializeRequest.class); McpSchema.InitializeRequest initRequest = (McpSchema.InitializeRequest) request.params(); - assertThat(initRequest.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + assertThat(initRequest.protocolVersion()).isEqualTo(transport.protocolVersions().get(0)); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION, null, + new McpSchema.InitializeResult(protocolVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { - assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + assertThat(result.protocolVersion()).isEqualTo(protocolVersion); }).verifyComplete(); - } finally { // Ensure cleanup happens even if test fails @@ -61,7 +62,7 @@ void shouldUseLatestVersionByDefault() { @Test void shouldNegotiateSpecificVersion() { String oldVersion = "0.1.0"; - MockMcpTransport transport = new MockMcpTransport(); + MockMcpClientTransport transport = new MockMcpClientTransport(); McpAsyncClient client = McpClient.async(transport) .clientInfo(CLIENT_INFO) .requestTimeout(REQUEST_TIMEOUT) @@ -79,7 +80,7 @@ void shouldNegotiateSpecificVersion() { assertThat(initRequest.protocolVersion()).isIn(List.of(oldVersion, McpSchema.LATEST_PROTOCOL_VERSION)); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(oldVersion, null, + new McpSchema.InitializeResult(oldVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { @@ -94,7 +95,7 @@ void shouldNegotiateSpecificVersion() { @Test void shouldFailForUnsupportedVersion() { String unsupportedVersion = "999.999.999"; - MockMcpTransport transport = new MockMcpTransport(); + MockMcpClientTransport transport = new MockMcpClientTransport(); McpAsyncClient client = McpClient.async(transport) .clientInfo(CLIENT_INFO) .requestTimeout(REQUEST_TIMEOUT) @@ -108,10 +109,10 @@ void shouldFailForUnsupportedVersion() { assertThat(request.params()).isInstanceOf(McpSchema.InitializeRequest.class); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(unsupportedVersion, null, + new McpSchema.InitializeResult(unsupportedVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); - }).expectError(McpError.class).verify(); + }).expectError(RuntimeException.class).verify(); } finally { StepVerifier.create(client.closeGracefully()).verifyComplete(); @@ -124,7 +125,7 @@ void shouldUseHighestVersionWhenMultipleSupported() { String middleVersion = "0.2.0"; String latestVersion = McpSchema.LATEST_PROTOCOL_VERSION; - MockMcpTransport transport = new MockMcpTransport(); + MockMcpClientTransport transport = new MockMcpClientTransport(); McpAsyncClient client = McpClient.async(transport) .clientInfo(CLIENT_INFO) .requestTimeout(REQUEST_TIMEOUT) @@ -141,7 +142,7 @@ void shouldUseHighestVersionWhenMultipleSupported() { assertThat(initRequest.protocolVersion()).isEqualTo(latestVersion); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(latestVersion, null, + new McpSchema.InitializeResult(latestVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/ServerParameterUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/ServerParameterUtils.java new file mode 100644 index 000000000..547ccc52f --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/ServerParameterUtils.java @@ -0,0 +1,21 @@ +package io.modelcontextprotocol.client; + +import io.modelcontextprotocol.client.transport.ServerParameters; + +public final class ServerParameterUtils { + + private ServerParameterUtils() { + } + + public static ServerParameters createServerParameters() { + if (System.getProperty("os.name").toLowerCase().contains("win")) { + return ServerParameters.builder("cmd.exe") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything@2025.12.18", "stdio") + .build(); + } + return ServerParameters.builder("npx") + .args("-y", "@modelcontextprotocol/server-everything@2025.12.18", "stdio") + .build(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java new file mode 100644 index 000000000..aa8aaa397 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -0,0 +1,47 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; + +import io.modelcontextprotocol.client.transport.ServerParameters; +import io.modelcontextprotocol.client.transport.StdioClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; + +import static io.modelcontextprotocol.client.ServerParameterUtils.createServerParameters; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + +/** + * Tests for the {@link McpAsyncClient} with {@link StdioClientTransport}. + * + *

    + * These tests use npx to download and run the MCP "everything" server locally. The first + * test execution will download the everything server scripts and cache them locally, + * which can take more than 15 seconds. Subsequent test runs will use the cached version + * and execute faster. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@Timeout(25) // Giving extra time beyond the client timeout to account for initial server + // download +class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { + + @Override + protected McpClientTransport createMcpTransport() { + return new StdioClientTransport(createServerParameters(), JSON_MAPPER); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(20); + } + + @Override + protected Duration getRequestTimeout() { + return Duration.ofSeconds(25); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java new file mode 100644 index 000000000..b1e567989 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -0,0 +1,78 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import io.modelcontextprotocol.client.transport.ServerParameters; +import io.modelcontextprotocol.client.transport.StdioClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.client.ServerParameterUtils.createServerParameters; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for the {@link McpSyncClient} with {@link StdioClientTransport}. + * + *

    + * These tests use npx to download and run the MCP "everything" server locally. The first + * test execution will download the everything server scripts and cache them locally, + * which can take more than 15 seconds. Subsequent test runs will use the cached version + * and execute faster. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@Timeout(25) // Giving extra time beyond the client timeout to account for initial server + // download +class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { + + @Override + protected McpClientTransport createMcpTransport() { + ServerParameters stdioParams = createServerParameters(); + return new StdioClientTransport(stdioParams, JSON_MAPPER); + } + + @Test + void customErrorHandlerShouldReceiveErrors() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + AtomicReference receivedError = new AtomicReference<>(); + + McpClientTransport transport = createMcpTransport(); + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + ((StdioClientTransport) transport).setStdErrorHandler(error -> { + receivedError.set(error); + latch.countDown(); + }); + + String errorMessage = "Test error"; + ((StdioClientTransport) transport).getErrorSink().emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST); + + assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); + + assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); + + StepVerifier.create(transport.closeGracefully()).expectComplete().verify(Duration.ofSeconds(5)); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(10); + } + + @Override + protected Duration getRequestTimeout() { + return Duration.ofSeconds(25); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java new file mode 100644 index 000000000..a24805a30 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -0,0 +1,480 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.util.UriComponentsBuilder; + +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.clearInvocations; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for the {@link HttpClientSseClientTransport} class. + * + * @author Christian Tzolov + */ +@Timeout(15) +class HttpClientSseClientTransportTests { + + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + private TestHttpClientSseClientTransport transport; + + private final McpTransportContext context = McpTransportContext.create(Map.of("some-key", "some-value")); + + // Test class to access protected methods + static class TestHttpClientSseClientTransport extends HttpClientSseClientTransport { + + private final AtomicInteger inboundMessageCount = new AtomicInteger(0); + + private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); + + public TestHttpClientSseClientTransport(final String baseUri) { + super(HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(), + HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", JSON_MAPPER, + McpAsyncHttpClientRequestCustomizer.NOOP); + } + + public int getInboundMessageCount() { + return inboundMessageCount.get(); + } + + public void simulateEndpointEvent(String jsonMessage) { + events.tryEmitNext(ServerSentEvent.builder().event("endpoint").data(jsonMessage).build()); + inboundMessageCount.incrementAndGet(); + } + + public void simulateMessageEvent(String jsonMessage) { + events.tryEmitNext(ServerSentEvent.builder().event("message").data(jsonMessage).build()); + inboundMessageCount.incrementAndGet(); + } + + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + @BeforeEach + void setUp() { + transport = new TestHttpClientSseClientTransport(host); + transport.connect(Function.identity()).block(); + } + + @AfterEach + void afterEach() { + if (transport != null) { + assertThatCode(() -> transport.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + } + + @Test + void testErrorOnBogusMessage() { + // bogus message + JSONRPCRequest bogusMessage = new JSONRPCRequest(null, null, "test-id", Map.of("key", "value")); + + StepVerifier.create(transport.sendMessage(bogusMessage)) + .verifyErrorMessage( + "Sending message failed with a non-OK HTTP code: 400 - Invalid message: {\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}"); + } + + @Test + void testMessageProcessing() { + // Create a test message + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Simulate receiving the message + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "test-method", + "id": "test-id", + "params": {"key": "value"} + } + """); + + // Subscribe to messages and verify + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testResponseMessageProcessing() { + // Simulate receiving a response message + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "id": "test-id", + "result": {"status": "success"} + } + """); + + // Create and send a request message + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Verify message handling + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testErrorMessageProcessing() { + // Simulate receiving an error message + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "id": "test-id", + "error": { + "code": -32600, + "message": "Invalid Request" + } + } + """); + + // Create and send a request message + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Verify message handling + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testNotificationMessageProcessing() { + // Simulate receiving a notification message (no id) + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "update", + "params": {"status": "processing"} + } + """); + + // Verify the notification was processed + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testGracefulShutdown() { + // Test graceful shutdown + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + // Create a test message + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Verify message is not processed after shutdown + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Message count should remain 0 after shutdown + assertThat(transport.getInboundMessageCount()).isZero(); + } + + @Test + void testRetryBehavior() { + // Create a client that simulates connection failures + HttpClientSseClientTransport failingTransport = HttpClientSseClientTransport.builder("http://non-existent-host") + .build(); + + // Verify that the transport attempts to reconnect + StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete(); + + // Clean up + failingTransport.closeGracefully().block(); + } + + @Test + void testMultipleMessageProcessing() { + // Simulate receiving multiple messages in sequence + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "method1", + "id": "id1", + "params": {"key": "value1"} + } + """); + + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "method2", + "id": "id2", + "params": {"key": "value2"} + } + """); + + // Create and send corresponding messages + JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", "id1", + Map.of("key", "value1")); + + JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", "id2", + Map.of("key", "value2")); + + // Verify both messages are processed + StepVerifier.create(transport.sendMessage(message1).then(transport.sendMessage(message2))).verifyComplete(); + + // Verify message count + assertThat(transport.getInboundMessageCount()).isEqualTo(2); + } + + @Test + void testMessageOrderPreservation() { + // Simulate receiving messages in a specific order + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "first", + "id": "1", + "params": {"sequence": 1} + } + """); + + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "second", + "id": "2", + "params": {"sequence": 2} + } + """); + + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "third", + "id": "3", + "params": {"sequence": 3} + } + """); + + // Verify message count and order + assertThat(transport.getInboundMessageCount()).isEqualTo(3); + } + + @Test + void testCustomizeClient() { + // Create an atomic boolean to verify the customizer was called + AtomicBoolean customizerCalled = new AtomicBoolean(false); + + // Create a transport with the customizer + HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) + .customizeClient(builder -> { + builder.version(HttpClient.Version.HTTP_2); + customizerCalled.set(true); + }) + .build(); + + // Verify the customizer was called + assertThat(customizerCalled.get()).isTrue(); + + // Clean up + customizedTransport.closeGracefully().block(); + } + + @Test + void testCustomizeRequest() { + // Create an atomic boolean to verify the customizer was called + AtomicBoolean customizerCalled = new AtomicBoolean(false); + + // Create a reference to store the custom header value + AtomicReference headerName = new AtomicReference<>(); + AtomicReference headerValue = new AtomicReference<>(); + + // Create a transport with the customizer + HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) + // Create a request customizer that adds a custom header + .customizeRequest(builder -> { + builder.header("X-Custom-Header", "test-value"); + customizerCalled.set(true); + + // Create a new request to verify the header was set + HttpRequest request = builder.uri(URI.create("http://example.com")).build(); + headerName.set("X-Custom-Header"); + headerValue.set(request.headers().firstValue("X-Custom-Header").orElse(null)); + }) + .build(); + + // Verify the customizer was called + assertThat(customizerCalled.get()).isTrue(); + + // Verify the header was set correctly + assertThat(headerName.get()).isEqualTo("X-Custom-Header"); + assertThat(headerValue.get()).isEqualTo("test-value"); + + // Clean up + customizedTransport.closeGracefully().block(); + } + + @Test + void testChainedCustomizations() { + // Create atomic booleans to verify both customizers were called + AtomicBoolean clientCustomizerCalled = new AtomicBoolean(false); + AtomicBoolean requestCustomizerCalled = new AtomicBoolean(false); + + // Create a transport with both customizers chained + HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) + .customizeClient(builder -> { + builder.connectTimeout(Duration.ofSeconds(30)); + clientCustomizerCalled.set(true); + }) + .customizeRequest(builder -> { + builder.header("X-Api-Key", "test-api-key"); + requestCustomizerCalled.set(true); + }) + .build(); + + // Verify both customizers were called + assertThat(clientCustomizerCalled.get()).isTrue(); + assertThat(requestCustomizerCalled.get()).isTrue(); + + // Clean up + customizedTransport.closeGracefully().block(); + } + + @Test + void testRequestCustomizer() { + var mockCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + + // Create a transport with the customizer + var customizedTransport = HttpClientSseClientTransport.builder(host) + .httpRequestCustomizer(mockCustomizer) + .build(); + + // Connect + StepVerifier + .create(customizedTransport.connect(Function.identity()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); + + // Verify the customizer was called + verify(mockCustomizer).customize(any(), eq("GET"), + eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull(), eq(context)); + clearInvocations(mockCustomizer); + + // Send test message + var testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Subscribe to messages and verify + StepVerifier + .create(customizedTransport.sendMessage(testMessage) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); + + // Verify the customizer was called + var uriArgumentCaptor = ArgumentCaptor.forClass(URI.class); + verify(mockCustomizer).customize(any(), eq("POST"), uriArgumentCaptor.capture(), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}"), + eq(context)); + assertThat(uriArgumentCaptor.getValue().toString()).startsWith(host + "/message?sessionId="); + + // Clean up + customizedTransport.closeGracefully().block(); + } + + @Test + void testAsyncRequestCustomizer() { + var mockCustomizer = mock(McpAsyncHttpClientRequestCustomizer.class); + when(mockCustomizer.customize(any(), any(), any(), any(), any())) + .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); + + // Create a transport with the customizer + var customizedTransport = HttpClientSseClientTransport.builder(host) + .asyncHttpRequestCustomizer(mockCustomizer) + .build(); + + // Connect + StepVerifier + .create(customizedTransport.connect(Function.identity()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); + + // Verify the customizer was called + verify(mockCustomizer).customize(any(), eq("GET"), + eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull(), eq(context)); + clearInvocations(mockCustomizer); + + // Send test message + var testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Subscribe to messages and verify + StepVerifier + .create(customizedTransport.sendMessage(testMessage) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); + + // Verify the customizer was called + var uriArgumentCaptor = ArgumentCaptor.forClass(URI.class); + verify(mockCustomizer).customize(any(), eq("POST"), uriArgumentCaptor.capture(), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}"), + eq(context)); + assertThat(uriArgumentCaptor.getValue().toString()).startsWith(host + "/message?sessionId="); + + // Clean up + customizedTransport.closeGracefully().block(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java new file mode 100644 index 000000000..81e642681 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import com.sun.net.httpserver.HttpServer; + +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import reactor.test.StepVerifier; + +/** + * Handles emplty application/json response with 200 OK status code. + * + * @author codezkk + */ +public class HttpClientStreamableHttpTransportEmptyJsonResponseTest { + + static int PORT = TomcatTestUtil.findAvailablePort(); + + static String host = "http://localhost:" + PORT; + + static HttpServer server; + + @BeforeAll + static void startContainer() throws IOException { + + server = HttpServer.create(new InetSocketAddress(PORT), 0); + + // Empty, 200 OK response for the /mcp endpoint + server.createContext("/mcp", exchange -> { + exchange.getResponseHeaders().set("Content-Type", "application/json"); + exchange.sendResponseHeaders(200, 0); + exchange.close(); + }); + + server.setExecutor(null); + server.start(); + } + + @AfterAll + static void stopContainer() { + server.stop(1); + } + + /** + * Regardless of the response (even if the response is null and the content-type is + * present), notify should handle it correctly. + */ + @Test + @Timeout(3) + void testNotificationInitialized() throws URISyntaxException { + + var uri = new URI(host + "/mcp"); + var mockRequestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + var transport = HttpClientStreamableHttpTransport.builder(host) + .httpRequestCustomizer(mockRequestCustomizer) + .build(); + + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Verify the customizer was called + verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"MCP Client\",\"version\":\"0.3.1\"}}}"), + any()); + + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java new file mode 100644 index 000000000..b82d6eb2c --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java @@ -0,0 +1,345 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import com.sun.net.httpserver.HttpServer; + +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.ProtocolVersions; +import reactor.test.StepVerifier; + +/** + * Tests for error handling changes in HttpClientStreamableHttpTransport. Specifically + * tests the distinction between session-related errors and general transport errors for + * 404 and 400 status codes. + * + * @author Christian Tzolov + */ +@Timeout(15) +public class HttpClientStreamableHttpTransportErrorHandlingTest { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String HOST = "http://localhost:" + PORT; + + private HttpServer server; + + private AtomicReference serverResponseStatus = new AtomicReference<>(200); + + private AtomicReference currentServerSessionId = new AtomicReference<>(null); + + private AtomicReference lastReceivedSessionId = new AtomicReference<>(null); + + private McpClientTransport transport; + + @BeforeEach + void startServer() throws IOException { + server = HttpServer.create(new InetSocketAddress(PORT), 0); + + // Configure the /mcp endpoint with dynamic response + server.createContext("/mcp", httpExchange -> { + if ("DELETE".equals(httpExchange.getRequestMethod())) { + httpExchange.sendResponseHeaders(200, 0); + } + else if ("POST".equals(httpExchange.getRequestMethod())) { + // Capture session ID from request if present + String requestSessionId = httpExchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + lastReceivedSessionId.set(requestSessionId); + + int status = serverResponseStatus.get(); + + // Set response headers + httpExchange.getResponseHeaders().set("Content-Type", "application/json"); + + // Add session ID to response if configured + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null) { + httpExchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + + // Send response based on configured status + if (status == 200) { + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + httpExchange.sendResponseHeaders(200, response.length()); + httpExchange.getResponseBody().write(response.getBytes()); + } + else { + httpExchange.sendResponseHeaders(status, 0); + } + } + httpExchange.close(); + }); + + server.setExecutor(null); + server.start(); + + transport = HttpClientStreamableHttpTransport.builder(HOST).build(); + } + + @AfterEach + void stopServer() { + if (server != null) { + server.stop(0); + } + } + + /** + * Test that 404 response WITHOUT session ID throws McpTransportException (not + * SessionNotFoundException) + */ + @Test + void test404WithoutSessionId() { + serverResponseStatus.set(404); + currentServerSessionId.set(null); // No session ID in response + + var testMessage = createTestRequestMessage(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(throwable -> throwable instanceof McpTransportException + && throwable.getMessage().contains("Not Found") && throwable.getMessage().contains("404") + && !(throwable instanceof McpTransportSessionNotFoundException)) + .verify(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that 404 response WITH session ID throws McpTransportSessionNotFoundException + */ + @Test + void test404WithSessionId() { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("test-session-123"); + + // Set up exception handler to verify session invalidation + @SuppressWarnings("unchecked") + Consumer exceptionHandler = mock(Consumer.class); + transport.setExceptionHandler(exceptionHandler); + + // Connect with handler + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send initial message to establish session + var testMessage = createTestRequestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // The session should now be established, next request will include session ID + // Now return 404 for next request + serverResponseStatus.set(404); + + // Send another message - should get SessionNotFoundException + StepVerifier.create(transport.sendMessage(testMessage)) + .expectError(McpTransportSessionNotFoundException.class) + .verify(); + + // Verify exception handler was called with SessionNotFoundException + verify(exceptionHandler).accept(any(McpTransportSessionNotFoundException.class)); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that 400 response WITHOUT session ID throws McpTransportException (not + * SessionNotFoundException) + */ + @Test + void test400WithoutSessionId() { + serverResponseStatus.set(400); + currentServerSessionId.set(null); // No session ID + + var testMessage = createTestRequestMessage(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(throwable -> throwable instanceof McpTransportException + && throwable.getMessage().contains("Bad Request") && throwable.getMessage().contains("400") + && !(throwable instanceof McpTransportSessionNotFoundException)) + .verify(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that 400 response WITH session ID throws McpTransportSessionNotFoundException + * This handles the case mentioned in the code comment about some implementations + * returning 400 for unknown session IDs. + */ + @Test + void test400WithSessionId() { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("test-session-456"); + + // Set up exception handler + @SuppressWarnings("unchecked") + Consumer exceptionHandler = mock(Consumer.class); + transport.setExceptionHandler(exceptionHandler); + + // Connect with handler + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send initial message to establish session + var testMessage = createTestRequestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // The session should now be established, next request will include session ID + // Now return 400 for next request (simulating unknown session ID) + serverResponseStatus.set(400); + + // Send another message - should get SessionNotFoundException + StepVerifier.create(transport.sendMessage(testMessage)) + .expectError(McpTransportSessionNotFoundException.class) + .verify(); + + // Verify exception handler was called + verify(exceptionHandler).accept(any(McpTransportSessionNotFoundException.class)); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test session recovery after SessionNotFoundException Verifies that a new session + * can be established after the old one is invalidated + */ + @Test + void testSessionRecoveryAfter404() { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("session-1"); + + // Send initial message to establish session + var testMessage = createTestRequestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + assertThat(lastReceivedSessionId.get()).isNull(); + + // The session should now be established + // Simulate session loss - return 404 + serverResponseStatus.set(404); + + // This should fail with SessionNotFoundException + StepVerifier.create(transport.sendMessage(testMessage)) + .expectError(McpTransportSessionNotFoundException.class) + .verify(); + + // Now server is back with new session + serverResponseStatus.set(200); + currentServerSessionId.set("session-2"); + lastReceivedSessionId.set(null); // Reset to verify new session + + // Should be able to establish new session + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Verify no session ID was sent (since old session was invalidated) + assertThat(lastReceivedSessionId.get()).isNull(); + + // Next request should use the new session ID + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Session ID should now be sent with requests + assertThat(lastReceivedSessionId.get()).isEqualTo("session-2"); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that reconnect (GET request) also properly handles 404/400 errors + */ + @Test + void testReconnectErrorHandling() { + + // Set up SSE endpoint for GET requests + server.createContext("/mcp-sse", exchange -> { + String method = exchange.getRequestMethod(); + String requestSessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + if ("GET".equals(method)) { + int status = serverResponseStatus.get(); + + if (status == 404 && requestSessionId != null) { + // 404 with session ID - should trigger SessionNotFoundException + exchange.sendResponseHeaders(404, 0); + } + else if (status == 404) { + // 404 without session ID - should trigger McpTransportException + exchange.sendResponseHeaders(404, 0); + } + else { + // Normal SSE response + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.sendResponseHeaders(200, 0); + // Send a test SSE event + String sseData = "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{}}\n\n"; + exchange.getResponseBody().write(sseData.getBytes()); + } + } + else { + // POST request handling + exchange.getResponseHeaders().set("Content-Type", "application/json"); + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null) { + exchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes()); + } + exchange.close(); + }); + + // Test with session ID - should get SessionNotFoundException + serverResponseStatus.set(200); + currentServerSessionId.set("sse-session-1"); + + var transport = HttpClientStreamableHttpTransport.builder(HOST) + .endpoint("/mcp-sse") + .openConnectionOnStartup(true) // This will trigger GET request on connect + .build(); + + // First connect successfully + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send message to establish session + var testMessage = createTestRequestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Now simulate server returning 404 on reconnect + serverResponseStatus.set(404); + + // This should trigger reconnect which will fail + // The error should be handled internally and passed to exception handler + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + private McpSchema.JSONRPCRequest createTestRequestMessage() { + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Test Client", "1.0.0")); + return new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "test-id", + initializeRequest); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java new file mode 100644 index 000000000..2ade30e17 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java @@ -0,0 +1,163 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.function.Consumer; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for the {@link HttpClientStreamableHttpTransport} class. + * + * @author Daniel Garnier-Moiroux + */ +class HttpClientStreamableHttpTransportTest { + + static String host = "http://localhost:3001"; + + private McpTransportContext context = McpTransportContext + .create(Map.of("test-transport-context-key", "some-value")); + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + void withTransport(HttpClientStreamableHttpTransport transport, Consumer c) { + try { + c.accept(transport); + } + finally { + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + } + + @Test + void testRequestCustomizer() throws URISyntaxException { + var uri = new URI(host + "/mcp"); + var mockRequestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + + var transport = HttpClientStreamableHttpTransport.builder(host) + .httpRequestCustomizer(mockRequestCustomizer) + .build(); + + withTransport(transport, (t) -> { + // Send test message + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier + .create(t.sendMessage(testMessage).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); + + // Verify the customizer was called + verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-06-18\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"MCP Client\",\"version\":\"0.3.1\"}}}"), + eq(context)); + }); + } + + @Test + void testAsyncRequestCustomizer() throws URISyntaxException { + var uri = new URI(host + "/mcp"); + var mockRequestCustomizer = mock(McpAsyncHttpClientRequestCustomizer.class); + when(mockRequestCustomizer.customize(any(), any(), any(), any(), any())) + .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); + + var transport = HttpClientStreamableHttpTransport.builder(host) + .asyncHttpRequestCustomizer(mockRequestCustomizer) + .build(); + + withTransport(transport, (t) -> { + // Send test message + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier + .create(t.sendMessage(testMessage).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); + + // Verify the customizer was called + verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-06-18\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"MCP Client\",\"version\":\"0.3.1\"}}}"), + eq(context)); + }); + } + + @Test + void testCloseUninitialized() { + var transport = HttpClientStreamableHttpTransport.builder(host).build(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMessage("MCP session has been closed") + .verify(); + } + + @Test + void testCloseInitialized() { + var transport = HttpClientStreamableHttpTransport.builder(host).build(); + + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(err -> err.getMessage().matches("MCP session with ID [a-zA-Z0-9-]* has been closed")) + .verify(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizerTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizerTest.java new file mode 100644 index 000000000..a04787aa3 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizerTest.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import io.modelcontextprotocol.common.McpTransportContext; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DelegatingMcpAsyncHttpClientRequestCustomizer}. + * + * @author Daniel Garnier-Moiroux + */ +class DelegatingMcpAsyncHttpClientRequestCustomizerTest { + + private static final URI TEST_URI = URI.create("https://example.com"); + + private final HttpRequest.Builder TEST_BUILDER = HttpRequest.newBuilder(TEST_URI); + + @Test + void delegates() { + var mockCustomizer = mock(McpAsyncHttpClientRequestCustomizer.class); + when(mockCustomizer.customize(any(), any(), any(), any(), any())) + .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); + var customizer = new DelegatingMcpAsyncHttpClientRequestCustomizer(List.of(mockCustomizer)); + + var context = McpTransportContext.EMPTY; + StepVerifier + .create(customizer.customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context)) + .expectNext(TEST_BUILDER) + .verifyComplete(); + + verify(mockCustomizer).customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context); + } + + @Test + void delegatesInOrder() { + var customizer = new DelegatingMcpAsyncHttpClientRequestCustomizer( + List.of((builder, method, uri, body, ctx) -> Mono.just(builder.copy().header("x-test", "one")), + (builder, method, uri, body, ctx) -> Mono.just(builder.copy().header("x-test", "two")))); + + var headers = Mono + .from(customizer.customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", + McpTransportContext.EMPTY)) + .map(HttpRequest.Builder::build) + .map(HttpRequest::headers) + .flatMapIterable(h -> h.allValues("x-test")); + + StepVerifier.create(headers).expectNext("one").expectNext("two").verifyComplete(); + } + + @Test + void constructorRequiresNonNull() { + assertThatThrownBy(() -> new DelegatingMcpAsyncHttpClientRequestCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Customizers must not be null"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizerTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizerTest.java new file mode 100644 index 000000000..6c51a3d12 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizerTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import io.modelcontextprotocol.common.McpTransportContext; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link DelegatingMcpSyncHttpClientRequestCustomizer}. + * + * @author Daniel Garnier-Moiroux + */ +class DelegatingMcpSyncHttpClientRequestCustomizerTest { + + private static final URI TEST_URI = URI.create("https://example.com"); + + private final HttpRequest.Builder TEST_BUILDER = HttpRequest.newBuilder(TEST_URI); + + @Test + void delegates() { + var mockCustomizer = Mockito.mock(McpSyncHttpClientRequestCustomizer.class); + var customizer = new DelegatingMcpSyncHttpClientRequestCustomizer(List.of(mockCustomizer)); + + var context = McpTransportContext.EMPTY; + customizer.customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context); + + verify(mockCustomizer).customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context); + } + + @Test + void delegatesInOrder() { + var testHeaderName = "x-test"; + var customizer = new DelegatingMcpSyncHttpClientRequestCustomizer( + List.of((builder, method, uri, body, ctx) -> builder.header(testHeaderName, "one"), + (builder, method, uri, body, ctx) -> builder.header(testHeaderName, "two"))); + + customizer.customize(TEST_BUILDER, "GET", TEST_URI, null, McpTransportContext.EMPTY); + var request = TEST_BUILDER.build(); + + assertThat(request.headers().allValues(testHeaderName)).containsExactly("one", "two"); + } + + @Test + void constructorRequiresNonNull() { + assertThatThrownBy(() -> new DelegatingMcpAsyncHttpClientRequestCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Customizers must not be null"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..8b2dea462 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,280 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.servlet.Servlet; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * async servers. + * + *

    + * This test class validates the end-to-end flow of transport context propagation in MCP + * communication, demonstrating how contextual information can be passed from client to + * server through HTTP headers and accessed within server-side handlers. + * + *

    Test Scenarios

    + *

    + * The tests cover multiple transport configurations with async servers: + *

      + *
    • Stateless server with async streamable HTTP clients
    • + *
    • Streamable server with async streamable HTTP clients
    • + *
    • SSE (Server-Sent Events) server with async SSE clients
    • + *
    + * + *

    Context Propagation Flow

    + *
      + *
    1. Client-side: Context data is stored in the Reactor Context and injected into HTTP + * headers via {@link McpSyncHttpClientRequestCustomizer}
    2. + *
    3. Transport: The context travels as HTTP headers (specifically "x-test" header in + * these tests)
    4. + *
    5. Server-side: A {@link McpTransportContextExtractor} extracts the header value and + * makes it available to request handlers through {@link McpTransportContext}
    6. + *
    7. Verification: The server echoes back the received context value as the tool call + * result
    8. + *
    + * + *

    + * All tests use an embedded Tomcat server running on a dynamically allocated port to + * ensure isolation and prevent port conflicts during parallel test execution. + * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + */ +@Timeout(15) +public class AsyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private Tomcat tomcat; + + private static final String HEADER_NAME = "x-test"; + + private final McpAsyncHttpClientRequestCustomizer asyncClientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + return Mono.just(builder); + }; + + private final McpTransportContextExtractor serverContextExtractor = (HttpServletRequest r) -> { + var headerValue = r.getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletStreamableServerTransportProvider streamableServerTransport = HttpServletStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletSseServerTransportProvider sseServerTransport = HttpServletSseServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/message") + .build(); + + private final McpAsyncClient asyncStreamableClient = McpClient + .async(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .asyncHttpRequestCustomizer(asyncClientRequestCustomizer) + .build()) + .build(); + + private final McpAsyncClient asyncSseClient = McpClient + .async(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .asyncHttpRequestCustomizer(asyncClientRequestCustomizer) + .build()) + .build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private final BiFunction> asyncStatelessHandler = ( + transportContext, request) -> { + return Mono + .just(new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null)); + }; + + private final BiFunction> asyncStatefulHandler = ( + exchange, request) -> { + return asyncStatelessHandler.apply(exchange.transportContext(), request); + }; + + @AfterEach + public void after() { + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (asyncStreamableClient != null) { + asyncStreamableClient.closeGracefully().block(); + } + if (asyncSseClient != null) { + asyncSseClient.closeGracefully().block(); + } + stopTomcat(); + } + + @Test + void asyncClinetStatelessServer() { + startTomcat(statelessServerTransport); + + var mcpServer = McpServer.async(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientStreamableServer() { + startTomcat(streamableServerTransport); + + var mcpServer = McpServer.async(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientSseServer() { + startTomcat(sseServerTransport); + + var mcpServer = McpServer.async(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncSseClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncSseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + private void startTomcat(Servlet transport) { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/common/HttpClientStreamableHttpVersionNegotiationIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/common/HttpClientStreamableHttpVersionNegotiationIntegrationTests.java new file mode 100644 index 000000000..8efb6a960 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/common/HttpClientStreamableHttpVersionNegotiationIntegrationTests.java @@ -0,0 +1,145 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.McpTestRequestRecordingServletFilter; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class HttpClientStreamableHttpVersionNegotiationIntegrationTests { + + private Tomcat tomcat; + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private final McpTestRequestRecordingServletFilter requestRecordingFilter = new McpTestRequestRecordingServletFilter(); + + private final HttpServletStreamableServerTransportProvider transport = HttpServletStreamableServerTransportProvider + .builder() + .contextExtractor( + req -> McpTransportContext.create(Map.of("protocol-version", req.getHeader("MCP-protocol-version")))) + .build(); + + private final McpSchema.Tool toolSpec = McpSchema.Tool.builder() + .name("test-tool") + .description("return the protocol version used") + .build(); + + private final BiFunction toolHandler = ( + exchange, request) -> new McpSchema.CallToolResult( + exchange.transportContext().get("protocol-version").toString(), null); + + McpSyncServer mcpServer = McpServer.sync(transport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(false).build()) + .tools(new McpServerFeatures.SyncToolSpecification(toolSpec, null, toolHandler)) + .build(); + + @AfterEach + void tearDown() { + stopTomcat(); + } + + @Test + void usesLatestVersion() { + startTomcat(); + + var client = McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT).build()) + .build(); + + client.initialize(); + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = requestRecordingFilter.getCalls(); + + assertThat(calls).filteredOn(c -> !c.body().contains("\"method\":\"initialize\"")) + // GET /mcp ; POST notification/initialized ; POST tools/call + .hasSize(3) + .map(McpTestRequestRecordingServletFilter.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_06_18)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_06_18); + mcpServer.close(); + } + + @Test + void usesServerSupportedVersion() { + startTomcat(); + + var transport = HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .supportedProtocolVersions(List.of(ProtocolVersions.MCP_2025_06_18, "2263-03-18")) + .build(); + var client = McpClient.sync(transport).build(); + + client.initialize(); + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = requestRecordingFilter.getCalls(); + // Initialize tells the server the Client's latest supported version + // FIXME: Set the correct protocol version on GET /mcp + assertThat(calls).filteredOn(c -> c.method().equals("POST") && !c.body().contains("\"method\":\"initialize\"")) + // POST notification/initialized ; POST tools/call + .hasSize(2) + .map(McpTestRequestRecordingServletFilter.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_06_18)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_06_18); + mcpServer.close(); + } + + private void startTomcat() { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport, requestRecordingFilter); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..cc8f4c4be --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,243 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpClient.SyncSpec; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.servlet.Servlet; +import jakarta.servlet.http.HttpServletRequest; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test both Client and Server {@link McpTransportContext} integration, in two steps. + *

    + * First, the client calls a tool and writes data stored in a thread-local to an HTTP + * header using {@link SyncSpec#transportContextProvider(Supplier)} and + * {@link McpSyncHttpClientRequestCustomizer}. + *

    + * Then the server reads the header with a {@link McpTransportContextExtractor} and + * returns the value as the result of the tool call. + * + * @author Daniel Garnier-Moiroux + */ +@Timeout(15) +public class SyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private Tomcat tomcat; + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final McpSyncHttpClientRequestCustomizer clientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + }; + + private final McpTransportContextExtractor serverContextExtractor = (HttpServletRequest r) -> { + var headerValue = r.getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final BiFunction statelessHandler = ( + transportContext, + request) -> new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null); + + private final BiFunction statefulHandler = ( + exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); + + private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletStreamableServerTransportProvider streamableServerTransport = HttpServletStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletSseServerTransportProvider sseServerTransport = HttpServletSseServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/message") + .build(); + + private final McpSyncClient streamableClient = McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSyncClient sseClient = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } + stopTomcat(); + } + + @Test + void statelessServer() { + startTomcat(statelessServerTransport); + + var mcpServer = McpServer.sync(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void streamableServer() { + startTomcat(streamableServerTransport); + + var mcpServer = McpServer.sync(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void sseServer() { + startTomcat(sseServerTransport); + + var mcpServer = McpServer.sync(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + private void startTomcat(Servlet transport) { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java new file mode 100644 index 000000000..090710248 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -0,0 +1,722 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpAsyncServer} that can be used with different + * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations. + * + * @author Christian Tzolov + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpAsyncServerTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo((McpSchema.Implementation) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + McpServer.AsyncSpecification builder = prepareAsyncServerBuilder(); + var mcpAsyncServer = builder.serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); + } + + @Test + void testImmediateClose() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpAsyncServer::close).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + @Test + @Deprecated + void testAddTool() { + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddToolCall() { + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() + .tool(newTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build())).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + @Deprecated + void testAddDuplicateTool() { + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(); + + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateToolCall() { + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build())).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testDuplicateToolCallDuringBuilding() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) // Duplicate! + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); + } + + @Test + void testDuplicateToolsInBatchListRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + List specs = List.of( + McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(), + McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build() // Duplicate! + ); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(specs) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-list-tool' is already registered."); + } + + @Test + void testDuplicateToolsInBatchVarargsRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(), + McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build() // Duplicate! + ) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-varargs-tool' is already registered."); + } + + @Test + void testRemoveTool() { + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(too, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(too, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(); + + StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testNotifyResourcesUpdated() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier + .create(mcpAsyncServer + .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResource(specification)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullSpecification() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalArgumentException.class).hasMessage("Resource must not be null"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testListResources() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.listResources().collectList())) + .expectNextMatches(resources -> resources.size() == 1 && resources.get(0).uri().equals(TEST_RESOURCE_URI)) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.removeResource(TEST_RESOURCE_URI))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + StepVerifier.create(mcpAsyncServer.removeResource("nonexistent://resource")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResourceTemplate(specification)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResourceTemplate(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("test://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("nonexistent://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + // Note: Based on the current implementation, listResourceTemplates() returns + // Flux + // This appears to be a bug in the implementation that should return + // Flux + StepVerifier.create(mcpAsyncServer.listResourceTemplates().collectList()) + .expectNextMatches(resources -> resources.size() >= 0) // Just verify it + // doesn't error + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullSpecification() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Prompt specification must not be null"); + }); + } + + @Test + void testAddPromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePrompt() { + String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; + + Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpAsyncServer2 = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeHandlers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + }))) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }), (exchange, roots) -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java new file mode 100644 index 000000000..1f5387f37 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java @@ -0,0 +1,1756 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.util.Utils; +import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +public abstract class AbstractMcpClientServerIntegrationTests { + + protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + abstract protected void prepareClients(int port, String mcpEndpoint); + + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); + + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + finally { + server.closeGracefully().block(); + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + return exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)) + .then(Mono.just(mock(CallToolResult.class))); + }) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + finally { + server.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { + + // Client + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + // Server + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(4)) + .tools(tool) + .build(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest).thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("1000ms"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> exchange.createElicitation(mock(ElicitRequest.class)) + .then(Mono.just(mock(CallToolResult.class)))) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without elicitation capabilities + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + finally { + server.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + assertWith(resultRef.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + var latch = new CountDownLatch(1); + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + try { + if (!latch.await(2, TimeUnit.SECONDS)) { + throw new RuntimeException("Timeout waiting for elicitation processing"); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = CallToolResult.builder().addContent(new TextContent("CALL RESPONSE")).build(); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) // 1 second. + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + ElicitResult elicitResult = resultRef.get(); + assertThat(elicitResult).isNull(); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsWithoutCapability(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + try ( + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsNotificationWithEmptyRootsList(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsWithMultipleHandlers(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var responseBodyIsNullOrBlank = new AtomicBoolean(false); + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=importantValue")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); + } + catch (Exception e) { + e.printStackTrace(); + } + + return callResponse; + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(responseBodyIsNullOrBlank.get()).isFalse(); + assertThat(response).isNotNull().isEqualTo(callResponse); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpSyncServer mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool1") + .description("tool1 description") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + // We trigger a timeout on blocking read, raising an exception + Mono.never().block(Duration.ofSeconds(1)); + return null; + }) + .build()) + .build(); + + try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // We expect the tool call to fail immediately with the exception raised by + // the offending tool instead of getting back a timeout. + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) + .withMessageContaining("Timeout on blocking read"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolCallSuccessWithTransportContextExtraction(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var transportContextIsNull = new AtomicBoolean(false); + var transportContextIsEmpty = new AtomicBoolean(false); + var responseBodyIsNullOrBlank = new AtomicBoolean(false); + + var expectedCallResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=value")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + McpTransportContext transportContext = exchange.transportContext(); + transportContextIsNull.set(transportContext == null); + transportContextIsEmpty.set(transportContext.equals(McpTransportContext.EMPTY)); + String ctxValue = (String) transportContext.get("important"); + + try { + String responseBody = "TOOL RESPONSE"; + responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); + } + catch (Exception e) { + e.printStackTrace(); + } + + return McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(transportContextIsNull.get()).isFalse(); + assertThat(transportContextIsEmpty.get()).isFalse(); + assertThat(responseBodyIsNullOrBlank.get()).isFalse(); + assertThat(response).isNotNull().isEqualTo(expectedCallResponse); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + return callResponse; + }) + .build(); + + AtomicReference> toolsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + toolsRef.set(toolsUpdate); + } + catch (Exception e) { + e.printStackTrace(); + } + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(toolsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool2") + .description("tool2 description") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).containsAll(List.of(tool2.tool())); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = prepareSyncServerBuilder().build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("logging-test") + .description("Test logging notifications") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications with different levels + + //@formatter:off + return exchange // This should be filtered out (DEBUG < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Logging test completed"))) + .isError(false) + .build()); + //@formatter:on + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Progress Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testProgressNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress + // token + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("progress-test") + .description("Test progress notifications") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications + var progressToken = (String) request.meta().get("progressToken"); + + return exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) + .then(// Send a progress notification with another progress value + // should + exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", + 0.0, 1.0, "Another processing started"))) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Progress test completed"))) + .isError(false) + .build()); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with progress notification handler + var mcpClient = clientBuilder.progressConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that sends progress notifications + McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() + .name("progress-test") + .meta(Map.of("progressToken", "test-progress-token")) + .build(); + CallToolResult result = mcpClient.callTool(callToolRequest); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); + + // Second notification should be 0.5/1.0 progress + assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); + + // Third notification should be another progress token with 0.0/1.0 progress + assertThat(notificationMap.get("Another processing started").progressToken()) + .isEqualTo("another-progress-token"); + assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Another processing started").message()) + .isEqualTo("Another processing started"); + + // Fourth notification should be 1.0/1.0 progress + assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @MethodSource("clientsForTesting") + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (mcpSyncServerExchange, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new McpSchema.PromptReference(PromptReference.TYPE, "code_review", "Code review"), + completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testPingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that uses ping functionality + AtomicReference executionOrder = new AtomicReference<>(""); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("ping-async-test") + .description("Test ping async behavior") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + + executionOrder.set(executionOrder.get() + "1"); + + // Test async ping behavior + return exchange.ping().doOnNext(result -> { + + assertThat(result).isNotNull(); + // Ping should return an empty object or map + assertThat(result).isInstanceOf(Map.class); + + executionOrder.set(executionOrder.get() + "2"); + assertThat(result).isNotNull(); + }).then(Mono.fromCallable(() -> { + executionOrder.set(executionOrder.get() + "3"); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Async ping test completed"))) + .isError(false) + .build(); + })); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that tests ping async behavior + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); + + // Verify execution order + assertThat(executionOrder.get()).isEqualTo("123"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + // In WebMVC, structured content is returned properly + if (response.structuredContent() != null) { + assertThat((Map) response.structuredContent()).containsEntry("result", 5.0) + .containsEntry("operation", "2 + 3") + .containsEntry("timestamp", "2024-01-01T10:00:00Z"); + } + else { + // Fallback to checking content if structured content is not available + assertThat(response.content()).isNotEmpty(); + } + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that returns an error result + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationFailure(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputMissingStructuredContent(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputRuntimeToolAddition(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification toolSpec = McpServerFeatures.SyncToolSpecification.builder() + .tool(dynamicTool) + .callHandler((exchange, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }) + .build(); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java new file mode 100644 index 000000000..915c658e3 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -0,0 +1,678 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpSyncServer} that can be used with different + * {@link McpServerTransportProvider} implementations. + * + * @author Christian Tzolov + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpSyncServerTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + // onStart(); + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.sync((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testImmediateClose() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpSyncServer::close).doesNotThrowAnyException(); + } + + @Test + void testGetAsyncServer() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + @Test + @Deprecated + void testAddTool() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, + (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()))) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddToolCall() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() + .tool(newTool) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build())).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + @Deprecated + void testAddDuplicateTool() { + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, + (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()))) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateToolCall() { + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build())).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testDuplicateToolCallDuringBuilding() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) // Duplicate! + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); + } + + @Test + void testDuplicateToolsInBatchListRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + List specs = List.of( + McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(), + McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build() // Duplicate! + ); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(specs) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-list-tool' is already registered."); + } + + @Test + void testDuplicateToolsInBatchVarargsRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(), + McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, + request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build() // Duplicate! + ) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-varargs-tool' is already registered."); + } + + @Test + void testRemoveTool() { + Tool tool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(tool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removeTool("nonexistent-tool")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpSyncServer::notifyToolsListChanged).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpSyncServer::notifyResourcesListChanged).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testNotifyResourcesUpdated() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer + .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullSpecification() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Resource must not be null"); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceWithoutCapability() { + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testListResources() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + List resources = mcpSyncServer.listResources(); + + assertThat(resources).hasSize(1); + assertThat(resources.get(0).uri()).isEqualTo(TEST_RESOURCE_URI); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + assertThatCode(() -> mcpSyncServer.removeResource(TEST_RESOURCE_URI)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + assertThatCode(() -> mcpSyncServer.removeResource("nonexistent://resource")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResourceTemplate(specification)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResourceTemplate(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("test://template/{id}")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("nonexistent://template/{id}")) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + List templates = mcpSyncServer.listResourceTemplates(); + + assertThat(templates).isNotNull(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpSyncServer::notifyPromptsListChanged).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullSpecification() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Prompt specification must not be null"); + } + + @Test + void testAddPromptWithoutCapability() { + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePromptWithoutCapability() { + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePrompt() { + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removePrompt("nonexistent://template/{id}")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeHandlers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + })) + .build(); + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(singleConsumerServer::closeGracefully).doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }, (exchange, roots) -> consumer2Called[0] = true)) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(multipleConsumersServer::closeGracefully).doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(errorHandlingServer::closeGracefully).doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(noConsumersServer::closeGracefully).doesNotThrowAnyException(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java new file mode 100644 index 000000000..62332fcdb --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java @@ -0,0 +1,267 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Tests for {@link McpServerFeatures.AsyncToolSpecification.Builder}. + * + * @author Christian Tzolov + */ +class AsyncToolSpecificationBuilderTest { + + @Test + void builderShouldCreateValidAsyncToolSpecification() { + + Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .title("A test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + McpServerFeatures.AsyncToolSpecification specification = McpServerFeatures.AsyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of(new TextContent("Test result"))).isError(false).build())) + .build(); + + assertThat(specification).isNotNull(); + assertThat(specification.tool()).isEqualTo(tool); + assertThat(specification.callHandler()).isNotNull(); + assertThat(specification.call()).isNull(); // deprecated field should be null + } + + @Test + void builderShouldThrowExceptionWhenToolIsNull() { + assertThatThrownBy(() -> McpServerFeatures.AsyncToolSpecification.builder() + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Tool must not be null"); + } + + @Test + void builderShouldThrowExceptionWhenCallToolIsNull() { + Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .title("A test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> McpServerFeatures.AsyncToolSpecification.builder().tool(tool).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Call handler function must not be null"); + } + + @Test + void builderShouldAllowMethodChaining() { + Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .title("A test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + McpServerFeatures.AsyncToolSpecification.Builder builder = McpServerFeatures.AsyncToolSpecification.builder(); + + // Then - verify method chaining returns the same builder instance + assertThat(builder.tool(tool)).isSameAs(builder); + assertThat(builder.callHandler( + (exchange, request) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build()))) + .isSameAs(builder); + } + + @Test + void builtSpecificationShouldExecuteCallToolCorrectly() { + Tool tool = McpSchema.Tool.builder() + .name("calculator") + .title("Simple calculator") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "42"; + + McpServerFeatures.AsyncToolSpecification specification = McpServerFeatures.AsyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> { + return Mono.just(CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build()); + }) + .build(); + + CallToolRequest request = new CallToolRequest("calculator", Map.of()); + Mono resultMono = specification.callHandler().apply(null, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + } + + @Test + @SuppressWarnings("deprecation") + void deprecatedConstructorShouldWorkCorrectly() { + Tool tool = McpSchema.Tool.builder() + .name("deprecated-tool") + .title("A deprecated tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "deprecated result"; + + // Test the deprecated constructor that takes a 'call' function + McpServerFeatures.AsyncToolSpecification specification = new McpServerFeatures.AsyncToolSpecification(tool, + (exchange, + arguments) -> Mono.just(CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build())); + + assertThat(specification).isNotNull(); + assertThat(specification.tool()).isEqualTo(tool); + assertThat(specification.call()).isNotNull(); // deprecated field should be set + assertThat(specification.callHandler()).isNotNull(); // should be automatically + // created + + // Test that the callTool function works (it should delegate to the call function) + CallToolRequest request = new CallToolRequest("deprecated-tool", Map.of("arg1", "value1")); + Mono resultMono = specification.callHandler().apply(null, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + + // Test that the deprecated call function also works directly + Mono callResultMono = specification.call().apply(null, request.arguments()); + + StepVerifier.create(callResultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + } + + @Test + void fromSyncShouldConvertSyncToolSpecificationCorrectly() { + Tool tool = McpSchema.Tool.builder() + .name("sync-tool") + .title("A sync tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "sync result"; + + // Create a sync tool specification + McpServerFeatures.SyncToolSpecification syncSpec = McpServerFeatures.SyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build()) + .build(); + + // Convert to async using fromSync + McpServerFeatures.AsyncToolSpecification asyncSpec = McpServerFeatures.AsyncToolSpecification + .fromSync(syncSpec); + + assertThat(asyncSpec).isNotNull(); + assertThat(asyncSpec.tool()).isEqualTo(tool); + assertThat(asyncSpec.callHandler()).isNotNull(); + assertThat(asyncSpec.call()).isNull(); // should be null since sync spec doesn't + // have deprecated call + + // Test that the converted async specification works correctly + CallToolRequest request = new CallToolRequest("sync-tool", Map.of("param", "value")); + Mono resultMono = asyncSpec.callHandler().apply(null, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + } + + @Test + @SuppressWarnings("deprecation") + void fromSyncShouldConvertSyncToolSpecificationWithDeprecatedCallCorrectly() { + Tool tool = McpSchema.Tool.builder() + .name("sync-deprecated-tool") + .title("A sync tool with deprecated call") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "sync deprecated result"; + McpAsyncServerExchange nullExchange = null; // Mock or create a suitable exchange + // if needed + + // Create a sync tool specification using the deprecated constructor + McpServerFeatures.SyncToolSpecification syncSpec = new McpServerFeatures.SyncToolSpecification(tool, + (exchange, arguments) -> CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build()); + + // Convert to async using fromSync + McpServerFeatures.AsyncToolSpecification asyncSpec = McpServerFeatures.AsyncToolSpecification + .fromSync(syncSpec); + + assertThat(asyncSpec).isNotNull(); + assertThat(asyncSpec.tool()).isEqualTo(tool); + assertThat(asyncSpec.callHandler()).isNotNull(); + assertThat(asyncSpec.call()).isNotNull(); // should be set since sync spec has + // deprecated call + + // Test that the converted async specification works correctly via callTool + CallToolRequest request = new CallToolRequest("sync-deprecated-tool", Map.of("param", "value")); + Mono resultMono = asyncSpec.callHandler().apply(nullExchange, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + + // Test that the deprecated call function also works + Mono callResultMono = asyncSpec.call().apply(nullExchange, request.arguments()); + + StepVerifier.create(callResultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + } + + @Test + void fromSyncShouldReturnNullWhenSyncSpecIsNull() { + assertThat(McpServerFeatures.AsyncToolSpecification.fromSync(null)).isNull(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java new file mode 100644 index 000000000..d2b9d14d0 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import static org.assertj.core.api.Assertions.assertThat; + +@Timeout(15) +class HttpServletSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + private Tomcat tomcat; + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient")); + } + + @BeforeEach + public void before() { + // Create and configure the transport provider + mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + } + + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java new file mode 100644 index 000000000..491c2d4ed --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java @@ -0,0 +1,653 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.ProtocolVersions; +import net.javacrumbs.jsonunit.core.Option; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.web.client.RestClient; + +import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.APPLICATION_JSON; +import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.TEXT_EVENT_STREAM; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +@Timeout(15) +class HttpServletStatelessIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private HttpServletStatelessServerTransport mcpStatelessServerTransport; + + ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + private Tomcat tomcat; + + @BeforeEach + public void before() { + this.mcpStatelessServerTransport = HttpServletStatelessServerTransport.builder() + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpStatelessServerTransport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + } + + @AfterEach + public void after() { + if (mcpStatelessServerTransport != null) { + mcpStatelessServerTransport.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("CALL RESPONSE"))) + .isError(false) + .build(); + McpStatelessServerFeatures.SyncToolSpecification tool1 = new McpStatelessServerFeatures.SyncToolSpecification( + Tool.builder().name("tool1").title("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build(), + (transportContext, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + finally { + mcpServer.close(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport).build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + finally { + mcpServer.close(); + } + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @ValueSource(strings = { "httpclient" }) + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (transportContext, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpStatelessServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (transportContext, getPromptRequest) -> null)) + .completions(new McpStatelessServerFeatures.SyncCompletionSpecification( + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); + } + finally { + mcpServer.close(); + } + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( + calculatorTool, (transportContext, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + assertThatJson(((McpSchema.TextContent) response.content().get(0)).text()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + finally { + mcpServer.close(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that returns an error result + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationFailure(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( + calculatorTool, (transportContext, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + finally { + mcpServer.close(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputMissingStructuredContent(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( + calculatorTool, (transportContext, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .instructions("bla") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + finally { + mcpServer.close(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputRuntimeToolAddition(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification toolSpec = new McpStatelessServerFeatures.SyncToolSpecification( + dynamicTool, (transportContext, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + finally { + mcpServer.close(); + } + } + + @Test + void testThrownMcpErrorAndJsonRpcError() throws Exception { + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool testTool = Tool.builder().name("test").description("test").build(); + + McpStatelessServerFeatures.SyncToolSpecification toolSpec = new McpStatelessServerFeatures.SyncToolSpecification( + testTool, (transportContext, request) -> { + throw new RuntimeException("testing"); + }); + + mcpServer.addTool(toolSpec); + + McpSchema.CallToolRequest callToolRequest = new McpSchema.CallToolRequest("test", Map.of()); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_TOOLS_CALL, "test", callToolRequest); + + MockHttpServletRequest request = new MockHttpServletRequest("POST", CUSTOM_MESSAGE_ENDPOINT); + MockHttpServletResponse response = new MockHttpServletResponse(); + + byte[] content = JSON_MAPPER.writeValueAsBytes(jsonrpcRequest); + request.setContent(content); + request.addHeader("Content-Type", "application/json"); + request.addHeader("Content-Length", Integer.toString(content.length)); + request.addHeader("Content-Length", Integer.toString(content.length)); + request.addHeader("Accept", APPLICATION_JSON + ", " + TEXT_EVENT_STREAM); + request.addHeader("Content-Type", APPLICATION_JSON); + request.addHeader("Cache-Control", "no-cache"); + request.addHeader(HttpHeaders.PROTOCOL_VERSION, ProtocolVersions.MCP_2025_03_26); + + mcpStatelessServerTransport.service(request, response); + + McpSchema.JSONRPCResponse jsonrpcResponse = JSON_MAPPER.readValue(response.getContentAsByteArray(), + McpSchema.JSONRPCResponse.class); + + assertThat(jsonrpcResponse).isNotNull(); + assertThat(jsonrpcResponse.error()).isNotNull(); + assertThat(jsonrpcResponse.error().code()).isEqualTo(ErrorCodes.INTERNAL_ERROR); + assertThat(jsonrpcResponse.error().message()).isEqualTo("testing"); + + mcpServer.close(); + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java new file mode 100644 index 000000000..96f1524b7 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; + +/** + * Tests for {@link McpAsyncServer} using + * {@link HttpServletStreamableServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) +class HttpServletStreamableAsyncServerTests extends AbstractMcpAsyncServerTests { + + protected McpStreamableServerTransportProvider createMcpTransportProvider() { + return HttpServletStreamableServerTransportProvider.builder().mcpEndpoint("/mcp/message").build(); + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java new file mode 100644 index 000000000..81423e0c5 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java @@ -0,0 +1,102 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import static org.assertj.core.api.Assertions.assertThat; + +@Timeout(15) +class HttpServletStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private HttpServletStreamableServerTransportProvider mcpServerTransportProvider; + + private Tomcat tomcat; + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient")); + } + + @BeforeEach + public void before() { + // Create and configure the transport provider + mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder() + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .mcpEndpoint(MESSAGE_ENDPOINT) + .keepAliveInterval(Duration.ofSeconds(1)) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(MESSAGE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + } + + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java new file mode 100644 index 000000000..87c0712dc --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; + +/** + * Tests for {@link McpSyncServer} using + * {@link HttpServletStreamableServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) +class HttpServletStreamableSyncServerTests extends AbstractMcpSyncServerTests { + + protected McpStreamableServerTransportProvider createMcpTransportProvider() { + return HttpServletStreamableServerTransportProvider.builder().mcpEndpoint("/mcp/message").build(); + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java new file mode 100644 index 000000000..640d34c9c --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java @@ -0,0 +1,698 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link McpAsyncServerExchange}. + * + * @author Christian Tzolov + */ +class McpAsyncServerExchangeTests { + + @Mock + private McpServerSession mockSession; + + private McpSchema.ClientCapabilities clientCapabilities; + + private McpSchema.Implementation clientInfo; + + private McpAsyncServerExchange exchange; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + clientCapabilities = McpSchema.ClientCapabilities.builder().roots(true).build(); + + clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); + + exchange = new McpAsyncServerExchange("testSessionId", mockSession, clientCapabilities, clientInfo, + McpTransportContext.EMPTY); + } + + @Test + void testListRootsWithSinglePage() { + + List roots = Arrays.asList(new McpSchema.Root("file:///home/user/project1", "Project 1"), + new McpSchema.Root("file:///home/user/project2", "Project 2")); + McpSchema.ListRootsResult singlePageResult = new McpSchema.ListRootsResult(roots, null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.just(singlePageResult)); + + StepVerifier.create(exchange.listRoots()).assertNext(result -> { + assertThat(result.roots()).hasSize(2); + assertThat(result.roots().get(0).uri()).isEqualTo("file:///home/user/project1"); + assertThat(result.roots().get(0).name()).isEqualTo("Project 1"); + assertThat(result.roots().get(1).uri()).isEqualTo("file:///home/user/project2"); + assertThat(result.roots().get(1).name()).isEqualTo("Project 2"); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + }).verifyComplete(); + } + + @Test + void testListRootsWithMultiplePages() { + + List page1Roots = Arrays.asList(new McpSchema.Root("file:///home/user/project1", "Project 1"), + new McpSchema.Root("file:///home/user/project2", "Project 2")); + List page2Roots = Arrays.asList(new McpSchema.Root("file:///home/user/project3", "Project 3")); + + McpSchema.ListRootsResult page1Result = new McpSchema.ListRootsResult(page1Roots, "cursor1"); + McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), + any(TypeRef.class))) + .thenReturn(Mono.just(page1Result)); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), + any(TypeRef.class))) + .thenReturn(Mono.just(page2Result)); + + StepVerifier.create(exchange.listRoots()).assertNext(result -> { + assertThat(result.roots()).hasSize(3); + assertThat(result.roots().get(0).uri()).isEqualTo("file:///home/user/project1"); + assertThat(result.roots().get(1).uri()).isEqualTo("file:///home/user/project2"); + assertThat(result.roots().get(2).uri()).isEqualTo("file:///home/user/project3"); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + }).verifyComplete(); + } + + @Test + void testListRootsWithEmptyResult() { + + McpSchema.ListRootsResult emptyResult = new McpSchema.ListRootsResult(new ArrayList<>(), null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.just(emptyResult)); + + StepVerifier.create(exchange.listRoots()).assertNext(result -> { + assertThat(result.roots()).isEmpty(); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + }).verifyComplete(); + } + + @Test + void testListRootsWithSpecificCursor() { + + List roots = Arrays.asList(new McpSchema.Root("file:///home/user/project3", "Project 3")); + McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(roots, "nextCursor"); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("someCursor")), + any(TypeRef.class))) + .thenReturn(Mono.just(result)); + + StepVerifier.create(exchange.listRoots("someCursor")).assertNext(listResult -> { + assertThat(listResult.roots()).hasSize(1); + assertThat(listResult.roots().get(0).uri()).isEqualTo("file:///home/user/project3"); + assertThat(listResult.nextCursor()).isEqualTo("nextCursor"); + }).verifyComplete(); + } + + @Test + void testListRootsWithError() { + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Network error"))); + + // When & Then + StepVerifier.create(exchange.listRoots()).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("Network error"); + }); + } + + @Test + void testListRootsUnmodifiabilityAfterAccumulation() { + + List page1Roots = new ArrayList<>( + Arrays.asList(new McpSchema.Root("file:///home/user/project1", "Project 1"))); + List page2Roots = new ArrayList<>( + Arrays.asList(new McpSchema.Root("file:///home/user/project2", "Project 2"))); + + McpSchema.ListRootsResult page1Result = new McpSchema.ListRootsResult(page1Roots, "cursor1"); + McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), + any(TypeRef.class))) + .thenReturn(Mono.just(page1Result)); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), + any(TypeRef.class))) + .thenReturn(Mono.just(page2Result)); + + StepVerifier.create(exchange.listRoots()).assertNext(result -> { + // Verify the accumulated result is correct + assertThat(result.roots()).hasSize(2); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + + // Verify that clear() also throws UnsupportedOperationException + assertThatThrownBy(() -> result.roots().clear()).isInstanceOf(UnsupportedOperationException.class); + + // Verify that remove() also throws UnsupportedOperationException + assertThatThrownBy(() -> result.roots().remove(0)).isInstanceOf(UnsupportedOperationException.class); + }).verifyComplete(); + } + + @Test + void testGetClientCapabilities() { + assertThat(exchange.getClientCapabilities()).isEqualTo(clientCapabilities); + } + + @Test + void testGetClientInfo() { + assertThat(exchange.getClientInfo()).isEqualTo(clientInfo); + } + + // --------------------------------------- + // Logging Notification Tests + // --------------------------------------- + + @Test + void testLoggingNotificationWithNullMessage() { + StepVerifier.create(exchange.loggingNotification(null)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Logging message must not be null"); + }); + } + + @Test + void testSetMinLoggingLevelWithNullValue() { + assertThatThrownBy(() -> exchange.setMinLoggingLevel(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("minLoggingLevel must not be null"); + } + + @Test + void testLoggingNotificationWithAllowedLevel() { + McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Test error message") + .build(); + + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + .thenReturn(Mono.empty()); + + StepVerifier.create(exchange.loggingNotification(notification)).verifyComplete(); + + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.ERROR)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification)); + } + + @Test + void testLoggingNotificationWithFilteredLevel() { + exchange.setMinLoggingLevel(McpSchema.LoggingLevel.DEBUG); + verify(mockSession, times(1)).setMinLoggingLevel(eq(McpSchema.LoggingLevel.DEBUG)); + + McpSchema.LoggingMessageNotification debugNotification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message that should be filtered") + .build(); + + when(mockSession.isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.DEBUG))).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification))) + .thenReturn(Mono.empty()); + + StepVerifier.create(exchange.loggingNotification(debugNotification)).verifyComplete(); + + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.DEBUG)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + eq(debugNotification)); + + McpSchema.LoggingMessageNotification warningNotification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.WARNING) + .logger("test-logger") + .data("Debug message that should be filtered") + .build(); + + StepVerifier.create(exchange.loggingNotification(warningNotification)).verifyComplete(); + + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.WARNING)); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + eq(warningNotification)); + } + + @Test + void testLoggingNotificationWithSessionError() { + McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Test error message") + .build(); + + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + .thenReturn(Mono.error(new RuntimeException("Session error"))); + + StepVerifier.create(exchange.loggingNotification(notification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("Session error"); + }); + } + + // --------------------------------------- + // Create Elicitation Tests + // --------------------------------------- + + @Test + void testCreateElicitationWithNullCapabilities() { + // Given - Create exchange with null capabilities + McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, clientInfo); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + StepVerifier.create(exchangeWithNullCapabilities.createElicitation(elicitRequest)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + }); + + // Verify that sendRequest was never called due to null capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); + } + + @Test + void testCreateElicitationWithoutElicitationCapabilities() { + // Given - Create exchange without elicitation capabilities + McpSchema.ClientCapabilities capabilitiesWithoutElicitation = McpSchema.ClientCapabilities.builder() + .roots(true) + .build(); + + McpAsyncServerExchange exchangeWithoutElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithoutElicitation, clientInfo); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + StepVerifier.create(exchangeWithoutElicitation.createElicitation(elicitRequest)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + }); + + // Verify that sendRequest was never called due to missing elicitation + // capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); + } + + @Test + void testCreateElicitationWithComplexRequest() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + + // Create a complex elicit request with schema + java.util.Map requestedSchema = new java.util.HashMap<>(); + requestedSchema.put("type", "object"); + requestedSchema.put("properties", java.util.Map.of("name", java.util.Map.of("type", "string"), "age", + java.util.Map.of("type", "number"))); + requestedSchema.put("required", java.util.List.of("name")); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your personal information") + .requestedSchema(requestedSchema) + .build(); + + java.util.Map responseContent = new java.util.HashMap<>(); + responseContent.put("name", "John Doe"); + responseContent.put("age", 30); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(responseContent) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content()).isNotNull(); + assertThat(result.content().get("name")).isEqualTo("John Doe"); + assertThat(result.content().get("age")).isEqualTo(30); + }).verifyComplete(); + } + + @Test + void testCreateElicitationWithDeclineAction() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide sensitive information") + .build(); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.DECLINE) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.DECLINE); + }).verifyComplete(); + } + + @Test + void testCreateElicitationWithCancelAction() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your information") + .build(); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.CANCEL) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.CANCEL); + }).verifyComplete(); + } + + @Test + void testCreateElicitationWithSessionError() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Session communication error"))); + + StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("Session communication error"); + }); + } + + // --------------------------------------- + // Create Message Tests + // --------------------------------------- + + @Test + void testCreateMessageWithNullCapabilities() { + + McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, clientInfo); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello, world!")))) + .build(); + + StepVerifier.create(exchangeWithNullCapabilities.createMessage(createMessageRequest)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + }); + + // Verify that sendRequest was never called due to null capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), + any(TypeRef.class)); + } + + @Test + void testCreateMessageWithoutSamplingCapabilities() { + + McpSchema.ClientCapabilities capabilitiesWithoutSampling = McpSchema.ClientCapabilities.builder() + .roots(true) + .build(); + + McpAsyncServerExchange exchangeWithoutSampling = new McpAsyncServerExchange(mockSession, + capabilitiesWithoutSampling, clientInfo); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello, world!")))) + .build(); + + StepVerifier.create(exchangeWithoutSampling.createMessage(createMessageRequest)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + }); + + // Verify that sendRequest was never called due to missing sampling capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), + any(TypeRef.class)); + } + + @Test + void testCreateMessageWithBasicRequest() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, + clientInfo); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello, world!")))) + .build(); + + McpSchema.CreateMessageResult expectedResult = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(new McpSchema.TextContent("Hello! How can I help you today?")) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + assertThat(result.role()).isEqualTo(McpSchema.Role.ASSISTANT); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Hello! How can I help you today?"); + assertThat(result.model()).isEqualTo("gpt-4"); + assertThat(result.stopReason()).isEqualTo(McpSchema.CreateMessageResult.StopReason.END_TURN); + }).verifyComplete(); + } + + @Test + void testCreateMessageWithImageContent() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, + clientInfo); + + // Create request with image content + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays.asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.ImageContent(null, "...", + "image/jpeg")))) + .build(); + + McpSchema.CreateMessageResult expectedResult = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(new McpSchema.TextContent("I can see an image. It appears to be a photograph.")) + .model("gpt-4-vision") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + assertThat(result.role()).isEqualTo(McpSchema.Role.ASSISTANT); + assertThat(result.model()).isEqualTo("gpt-4-vision"); + }).verifyComplete(); + } + + @Test + void testCreateMessageWithSessionError() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, + clientInfo); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello")))) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Session communication error"))); + + StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("Session communication error"); + }); + } + + @Test + void testCreateMessageWithIncludeContext() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, + clientInfo); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays.asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("What files are available?")))) + .includeContext(McpSchema.CreateMessageRequest.ContextInclusionStrategy.ALL_SERVERS) + .build(); + + McpSchema.CreateMessageResult expectedResult = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(new McpSchema.TextContent("Based on the available context, I can see several files...")) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + assertThat(((McpSchema.TextContent) result.content()).text()).contains("context"); + }).verifyComplete(); + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- + + @Test + void testPingWithSuccessfulResponse() { + + java.util.Map expectedResponse = java.util.Map.of(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResponse)); + + StepVerifier.create(exchange.ping()).assertNext(result -> { + assertThat(result).isEqualTo(expectedResponse); + assertThat(result).isInstanceOf(java.util.Map.class); + }).verifyComplete(); + + // Verify that sendRequest was called with correct parameters + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); + } + + @Test + void testPingWithMcpError() { + // Given - Mock an MCP-specific error during ping + McpError mcpError = new McpError("Server unavailable"); + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) + .thenReturn(Mono.error(mcpError)); + + // When & Then + StepVerifier.create(exchange.ping()).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Server unavailable"); + }); + + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); + } + + @Test + void testPingMultipleCalls() { + + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) + .thenReturn(Mono.just(Map.of())) + .thenReturn(Mono.just(Map.of())); + + // First call + StepVerifier.create(exchange.ping()).assertNext(result -> { + assertThat(result).isInstanceOf(Map.class); + }).verifyComplete(); + + // Second call + StepVerifier.create(exchange.ping()).assertNext(result -> { + assertThat(result).isInstanceOf(Map.class); + }).verifyComplete(); + + // Verify that sendRequest was called twice + verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java new file mode 100644 index 000000000..54fb80a78 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java @@ -0,0 +1,327 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; + +import static org.assertj.core.api.Assertions.assertThat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpError; + +/** + * Tests for completion functionality with context support. + * + * @author Surbhi Bansal + */ +class McpCompletionTests { + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + McpClient.SyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + // Create and con figure the transport provider + mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build()); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + e.printStackTrace(); + } + } + } + + @Test + void testCompletionHandlerReceivesContext() { + AtomicReference receivedRequest = new AtomicReference<>(); + BiFunction completionHandler = (exchange, request) -> { + receivedRequest.set(request); + return new CompleteResult(new CompleteResult.CompleteCompletion(List.of("test-completion"), 1, false)); + }; + + ResourceReference resourceRef = new ResourceReference(ResourceReference.TYPE, "test://resource/{param}"); + + var resource = Resource.builder() + .uri("test://resource/{param}") + .name("Test Resource") + .description("A resource for testing") + .mimeType("text/plain") + .size(123L) + .build(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .resources(new McpServerFeatures.SyncResourceSpecification(resource, + (exchange, req) -> new ReadResourceResult(List.of()))) + .completions(new McpServerFeatures.SyncCompletionSpecification(resourceRef, completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Test with context + CompleteRequest request = new CompleteRequest(resourceRef, + new CompleteRequest.CompleteArgument("param", "test"), null, + new CompleteRequest.CompleteContext(Map.of("previous", "value"))); + + CompleteResult result = mcpClient.completeCompletion(request); + + // Verify handler received the context + assertThat(receivedRequest.get().context()).isNotNull(); + assertThat(receivedRequest.get().context().arguments()).containsEntry("previous", "value"); + assertThat(result.completion().values()).containsExactly("test-completion"); + } + + mcpServer.close(); + } + + @Test + void testCompletionBackwardCompatibility() { + AtomicReference contextWasNull = new AtomicReference<>(false); + BiFunction completionHandler = (exchange, request) -> { + contextWasNull.set(request.context() == null); + return new CompleteResult( + new CompleteResult.CompleteCompletion(List.of("no-context-completion"), 1, false)); + }; + + McpSchema.Prompt prompt = new Prompt("test-prompt", "this is a test prompt", + List.of(new PromptArgument("arg", "string", false))); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification(prompt, + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new PromptReference(PromptReference.TYPE, "test-prompt"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Test without context + CompleteRequest request = new CompleteRequest(new PromptReference(PromptReference.TYPE, "test-prompt"), + new CompleteRequest.CompleteArgument("arg", "val")); + + CompleteResult result = mcpClient.completeCompletion(request); + + // Verify context was null + assertThat(contextWasNull.get()).isTrue(); + assertThat(result.completion().values()).containsExactly("no-context-completion"); + } + + mcpServer.close(); + } + + @Test + void testDependentCompletionScenario() { + BiFunction completionHandler = (exchange, request) -> { + // Simulate database/table completion scenario + if (request.ref() instanceof ResourceReference resourceRef) { + if ("db://{database}/{table}".equals(resourceRef.uri())) { + if ("database".equals(request.argument().name())) { + // Complete database names + return new CompleteResult(new CompleteResult.CompleteCompletion( + List.of("users_db", "products_db", "analytics_db"), 3, false)); + } + else if ("table".equals(request.argument().name())) { + // Complete table names based on selected database + if (request.context() != null && request.context().arguments() != null) { + String db = request.context().arguments().get("database"); + if ("users_db".equals(db)) { + return new CompleteResult(new CompleteResult.CompleteCompletion( + List.of("users", "sessions", "permissions"), 3, false)); + } + else if ("products_db".equals(db)) { + return new CompleteResult(new CompleteResult.CompleteCompletion( + List.of("products", "categories", "inventory"), 3, false)); + } + } + } + } + } + return new CompleteResult(new CompleteResult.CompleteCompletion(List.of(), 0, false)); + }; + + McpSchema.Resource resource = Resource.builder() + .uri("db://{database}/{table}") + .name("Database Table") + .description("Resource representing a table in a database") + .mimeType("application/json") + .size(456L) + .build(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .resources(new McpServerFeatures.SyncResourceSpecification(resource, + (exchange, req) -> new ReadResourceResult(List.of()))) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // First, complete database + CompleteRequest dbRequest = new CompleteRequest( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("database", "")); + + CompleteResult dbResult = mcpClient.completeCompletion(dbRequest); + assertThat(dbResult.completion().values()).contains("users_db", "products_db"); + + // Then complete table with database context + CompleteRequest tableRequest = new CompleteRequest( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("table", ""), + new CompleteRequest.CompleteContext(Map.of("database", "users_db"))); + + CompleteResult tableResult = mcpClient.completeCompletion(tableRequest); + assertThat(tableResult.completion().values()).containsExactly("users", "sessions", "permissions"); + + // Different database gives different tables + CompleteRequest tableRequest2 = new CompleteRequest( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("table", ""), + new CompleteRequest.CompleteContext(Map.of("database", "products_db"))); + + CompleteResult tableResult2 = mcpClient.completeCompletion(tableRequest2); + assertThat(tableResult2.completion().values()).containsExactly("products", "categories", "inventory"); + } + + mcpServer.close(); + } + + @Test + void testCompletionErrorOnMissingContext() { + BiFunction completionHandler = (exchange, request) -> { + if (request.ref() instanceof ResourceReference resourceRef) { + if ("db://{database}/{table}".equals(resourceRef.uri())) { + if ("table".equals(request.argument().name())) { + // Check if database context is provided + if (request.context() == null || request.context().arguments() == null + || !request.context().arguments().containsKey("database")) { + + throw McpError.builder(ErrorCodes.INVALID_REQUEST) + .message("Please select a database first to see available tables") + .build(); + } + // Normal completion if context is provided + String db = request.context().arguments().get("database"); + if ("test_db".equals(db)) { + return new CompleteResult(new CompleteResult.CompleteCompletion( + List.of("users", "orders", "products"), 3, false)); + } + } + } + } + return new CompleteResult(new CompleteResult.CompleteCompletion(List.of(), 0, false)); + }; + + McpSchema.Resource resource = Resource.builder() + .uri("db://{database}/{table}") + .name("Database Table") + .description("Resource representing a table in a database") + .mimeType("application/json") + .size(456L) + .build(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .resources(new McpServerFeatures.SyncResourceSpecification(resource, + (exchange, req) -> new ReadResourceResult(List.of()))) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample" + "client", "0.0.0")) + .build();) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Try to complete table without database context - should raise error + CompleteRequest requestWithoutContext = new CompleteRequest( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("table", "")); + + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.completeCompletion(requestWithoutContext)) + .withMessageContaining("Please select a database first"); + + // Now complete with proper context - should work normally + CompleteRequest requestWithContext = new CompleteRequest( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("table", ""), + new CompleteRequest.CompleteContext(Map.of("database", "test_db"))); + + CompleteResult resultWithContext = mcpClient.completeCompletion(requestWithContext); + assertThat(resultWithContext.completion().values()).containsExactly("users", "orders", "products"); + } + + mcpServer.close(); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java similarity index 61% rename from mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java index 97358723f..cdd2bacb7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java @@ -7,7 +7,8 @@ import java.util.List; import java.util.UUID; -import io.modelcontextprotocol.MockMcpTransport; +import io.modelcontextprotocol.MockMcpServerTransport; +import io.modelcontextprotocol.MockMcpServerTransportProvider; import io.modelcontextprotocol.spec.McpSchema; import org.junit.jupiter.api.Test; @@ -29,20 +30,24 @@ private McpSchema.JSONRPCRequest jsonRpcInitializeRequest(String requestId, Stri @Test void shouldUseLatestVersionByDefault() { - MockMcpTransport transport = new MockMcpTransport(); - McpAsyncServer server = McpServer.async(transport).serverInfo(SERVER_INFO).build(); + MockMcpServerTransport serverTransport = new MockMcpServerTransport(); + var transportProvider = new MockMcpServerTransportProvider(serverTransport); + McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); String requestId = UUID.randomUUID().toString(); - transport.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, McpSchema.LATEST_PROTOCOL_VERSION)); + transportProvider + .simulateIncomingMessage(jsonRpcInitializeRequest(requestId, McpSchema.LATEST_PROTOCOL_VERSION)); - McpSchema.JSONRPCMessage response = transport.getLastSentMessage(); + McpSchema.JSONRPCMessage response = serverTransport.getLastSentMessage(); assertThat(response).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse jsonResponse = (McpSchema.JSONRPCResponse) response; assertThat(jsonResponse.id()).isEqualTo(requestId); assertThat(jsonResponse.result()).isInstanceOf(McpSchema.InitializeResult.class); McpSchema.InitializeResult result = (McpSchema.InitializeResult) jsonResponse.result(); - assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + + var protocolVersion = transportProvider.protocolVersions().get(transportProvider.protocolVersions().size() - 1); + assertThat(result.protocolVersion()).isEqualTo(protocolVersion); server.closeGracefully().subscribe(); } @@ -50,16 +55,18 @@ void shouldUseLatestVersionByDefault() { @Test void shouldNegotiateSpecificVersion() { String oldVersion = "0.1.0"; - MockMcpTransport transport = new MockMcpTransport(); - McpAsyncServer server = McpServer.async(transport).serverInfo(SERVER_INFO).build(); + MockMcpServerTransport serverTransport = new MockMcpServerTransport(); + var transportProvider = new MockMcpServerTransportProvider(serverTransport); + + McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); server.setProtocolVersions(List.of(oldVersion, McpSchema.LATEST_PROTOCOL_VERSION)); String requestId = UUID.randomUUID().toString(); - transport.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, oldVersion)); + transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, oldVersion)); - McpSchema.JSONRPCMessage response = transport.getLastSentMessage(); + McpSchema.JSONRPCMessage response = serverTransport.getLastSentMessage(); assertThat(response).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse jsonResponse = (McpSchema.JSONRPCResponse) response; assertThat(jsonResponse.id()).isEqualTo(requestId); @@ -73,20 +80,23 @@ void shouldNegotiateSpecificVersion() { @Test void shouldSuggestLatestVersionForUnsupportedVersion() { String unsupportedVersion = "999.999.999"; - MockMcpTransport transport = new MockMcpTransport(); - McpAsyncServer server = McpServer.async(transport).serverInfo(SERVER_INFO).build(); + MockMcpServerTransport serverTransport = new MockMcpServerTransport(); + var transportProvider = new MockMcpServerTransportProvider(serverTransport); + + McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); String requestId = UUID.randomUUID().toString(); - transport.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, unsupportedVersion)); + transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, unsupportedVersion)); - McpSchema.JSONRPCMessage response = transport.getLastSentMessage(); + McpSchema.JSONRPCMessage response = serverTransport.getLastSentMessage(); assertThat(response).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse jsonResponse = (McpSchema.JSONRPCResponse) response; assertThat(jsonResponse.id()).isEqualTo(requestId); assertThat(jsonResponse.result()).isInstanceOf(McpSchema.InitializeResult.class); McpSchema.InitializeResult result = (McpSchema.InitializeResult) jsonResponse.result(); - assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + var protocolVersion = transportProvider.protocolVersions().get(transportProvider.protocolVersions().size() - 1); + assertThat(result.protocolVersion()).isEqualTo(protocolVersion); server.closeGracefully().subscribe(); } @@ -97,15 +107,17 @@ void shouldUseHighestVersionWhenMultipleSupported() { String middleVersion = "0.2.0"; String latestVersion = McpSchema.LATEST_PROTOCOL_VERSION; - MockMcpTransport transport = new MockMcpTransport(); - McpAsyncServer server = McpServer.async(transport).serverInfo(SERVER_INFO).build(); + MockMcpServerTransport serverTransport = new MockMcpServerTransport(); + var transportProvider = new MockMcpServerTransportProvider(serverTransport); + + McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); server.setProtocolVersions(List.of(oldVersion, middleVersion, latestVersion)); String requestId = UUID.randomUUID().toString(); - transport.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, latestVersion)); + transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, latestVersion)); - McpSchema.JSONRPCMessage response = transport.getLastSentMessage(); + McpSchema.JSONRPCMessage response = serverTransport.getLastSentMessage(); assertThat(response).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse jsonResponse = (McpSchema.JSONRPCResponse) response; assertThat(jsonResponse.id()).isEqualTo(requestId); diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java new file mode 100644 index 000000000..069d0f896 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java @@ -0,0 +1,692 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link McpSyncServerExchange}. + * + * @author Christian Tzolov + */ +class McpSyncServerExchangeTests { + + @Mock + private McpServerSession mockSession; + + private McpSchema.ClientCapabilities clientCapabilities; + + private McpSchema.Implementation clientInfo; + + private McpAsyncServerExchange asyncExchange; + + private McpSyncServerExchange exchange; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + clientCapabilities = McpSchema.ClientCapabilities.builder().roots(true).build(); + + clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); + + asyncExchange = new McpAsyncServerExchange(mockSession, clientCapabilities, clientInfo); + exchange = new McpSyncServerExchange(asyncExchange); + } + + @Test + void testListRootsWithSinglePage() { + + List roots = Arrays.asList(new McpSchema.Root("file:///home/user/project1", "Project 1"), + new McpSchema.Root("file:///home/user/project2", "Project 2")); + McpSchema.ListRootsResult singlePageResult = new McpSchema.ListRootsResult(roots, null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.just(singlePageResult)); + + McpSchema.ListRootsResult result = exchange.listRoots(); + + assertThat(result.roots()).hasSize(2); + assertThat(result.roots().get(0).uri()).isEqualTo("file:///home/user/project1"); + assertThat(result.roots().get(0).name()).isEqualTo("Project 1"); + assertThat(result.roots().get(1).uri()).isEqualTo("file:///home/user/project2"); + assertThat(result.roots().get(1).name()).isEqualTo("Project 2"); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + } + + @Test + void testListRootsWithMultiplePages() { + + List page1Roots = Arrays.asList(new McpSchema.Root("file:///home/user/project1", "Project 1"), + new McpSchema.Root("file:///home/user/project2", "Project 2")); + List page2Roots = Arrays.asList(new McpSchema.Root("file:///home/user/project3", "Project 3")); + + McpSchema.ListRootsResult page1Result = new McpSchema.ListRootsResult(page1Roots, "cursor1"); + McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), + any(TypeRef.class))) + .thenReturn(Mono.just(page1Result)); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), + any(TypeRef.class))) + .thenReturn(Mono.just(page2Result)); + + McpSchema.ListRootsResult result = exchange.listRoots(); + + assertThat(result.roots()).hasSize(3); + assertThat(result.roots().get(0).uri()).isEqualTo("file:///home/user/project1"); + assertThat(result.roots().get(1).uri()).isEqualTo("file:///home/user/project2"); + assertThat(result.roots().get(2).uri()).isEqualTo("file:///home/user/project3"); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + } + + @Test + void testListRootsWithEmptyResult() { + + McpSchema.ListRootsResult emptyResult = new McpSchema.ListRootsResult(new ArrayList<>(), null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.just(emptyResult)); + + McpSchema.ListRootsResult result = exchange.listRoots(); + + assertThat(result.roots()).isEmpty(); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + } + + @Test + void testListRootsWithSpecificCursor() { + + List roots = Arrays.asList(new McpSchema.Root("file:///home/user/project3", "Project 3")); + McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(roots, "nextCursor"); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("someCursor")), + any(TypeRef.class))) + .thenReturn(Mono.just(result)); + + McpSchema.ListRootsResult listResult = exchange.listRoots("someCursor"); + + assertThat(listResult.roots()).hasSize(1); + assertThat(listResult.roots().get(0).uri()).isEqualTo("file:///home/user/project3"); + assertThat(listResult.nextCursor()).isEqualTo("nextCursor"); + } + + @Test + void testListRootsWithError() { + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Network error"))); + + // When & Then + assertThatThrownBy(() -> exchange.listRoots()).isInstanceOf(RuntimeException.class).hasMessage("Network error"); + } + + @Test + void testListRootsUnmodifiabilityAfterAccumulation() { + + List page1Roots = new ArrayList<>( + Arrays.asList(new McpSchema.Root("file:///home/user/project1", "Project 1"))); + List page2Roots = new ArrayList<>( + Arrays.asList(new McpSchema.Root("file:///home/user/project2", "Project 2"))); + + McpSchema.ListRootsResult page1Result = new McpSchema.ListRootsResult(page1Roots, "cursor1"); + McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), + any(TypeRef.class))) + .thenReturn(Mono.just(page1Result)); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), + any(TypeRef.class))) + .thenReturn(Mono.just(page2Result)); + + McpSchema.ListRootsResult result = exchange.listRoots(); + + // Verify the accumulated result is correct + assertThat(result.roots()).hasSize(2); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + + // Verify that clear() also throws UnsupportedOperationException + assertThatThrownBy(() -> result.roots().clear()).isInstanceOf(UnsupportedOperationException.class); + + // Verify that remove() also throws UnsupportedOperationException + assertThatThrownBy(() -> result.roots().remove(0)).isInstanceOf(UnsupportedOperationException.class); + } + + @Test + void testGetClientCapabilities() { + assertThat(exchange.getClientCapabilities()).isEqualTo(clientCapabilities); + } + + @Test + void testGetClientInfo() { + assertThat(exchange.getClientInfo()).isEqualTo(clientInfo); + } + + // --------------------------------------- + // Logging Notification Tests + // --------------------------------------- + + @Test + void testLoggingNotificationWithNullMessage() { + assertThatThrownBy(() -> exchange.loggingNotification(null)).isInstanceOf(McpError.class) + .hasMessage("Logging message must not be null"); + } + + @Test + void testLoggingNotificationWithAllowedLevel() { + + McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Test error message") + .build(); + + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + .thenReturn(Mono.empty()); + + exchange.loggingNotification(notification); + + // Verify that sendNotification was called exactly once + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.ERROR)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification)); + } + + @Test + void testLoggingNotificationWithFilteredLevel() { + asyncExchange.setMinLoggingLevel(McpSchema.LoggingLevel.DEBUG); + verify(mockSession, times(1)).setMinLoggingLevel(McpSchema.LoggingLevel.DEBUG); + + McpSchema.LoggingMessageNotification debugNotification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message that should be filtered") + .build(); + + when(mockSession.isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.DEBUG))).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification))) + .thenReturn(Mono.empty()); + + exchange.loggingNotification(debugNotification); + + verify(mockSession, times(1)).isNotificationForLevelAllowed(McpSchema.LoggingLevel.DEBUG); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + eq(debugNotification)); + + McpSchema.LoggingMessageNotification warningNotification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.WARNING) + .logger("test-logger") + .data("Debug message that should be filtered") + .build(); + + exchange.loggingNotification(warningNotification); + + verify(mockSession, times(1)).isNotificationForLevelAllowed(McpSchema.LoggingLevel.WARNING); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + eq(warningNotification)); + } + + @Test + void testLoggingNotificationWithSessionError() { + + McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Test error message") + .build(); + + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + .thenReturn(Mono.error(new RuntimeException("Session error"))); + + assertThatThrownBy(() -> exchange.loggingNotification(notification)).isInstanceOf(RuntimeException.class) + .hasMessage("Session error"); + } + + // --------------------------------------- + // Create Elicitation Tests + // --------------------------------------- + + @Test + void testCreateElicitationWithNullCapabilities() { + // Given - Create exchange with null capabilities + McpAsyncServerExchange asyncExchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, + clientInfo); + McpSyncServerExchange exchangeWithNullCapabilities = new McpSyncServerExchange( + asyncExchangeWithNullCapabilities); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + assertThatThrownBy(() -> exchangeWithNullCapabilities.createElicitation(elicitRequest)) + .isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + + // Verify that sendRequest was never called due to null capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); + } + + @Test + void testCreateElicitationWithoutElicitationCapabilities() { + // Given - Create exchange without elicitation capabilities + McpSchema.ClientCapabilities capabilitiesWithoutElicitation = McpSchema.ClientCapabilities.builder() + .roots(true) + .build(); + + McpAsyncServerExchange asyncExchangeWithoutElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithoutElicitation, clientInfo); + McpSyncServerExchange exchangeWithoutElicitation = new McpSyncServerExchange(asyncExchangeWithoutElicitation); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + assertThatThrownBy(() -> exchangeWithoutElicitation.createElicitation(elicitRequest)) + .isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + + // Verify that sendRequest was never called due to missing elicitation + // capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); + } + + @Test + void testCreateElicitationWithComplexRequest() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); + + // Create a complex elicit request with schema + java.util.Map requestedSchema = new java.util.HashMap<>(); + requestedSchema.put("type", "object"); + requestedSchema.put("properties", java.util.Map.of("name", java.util.Map.of("type", "string"), "age", + java.util.Map.of("type", "number"))); + requestedSchema.put("required", java.util.List.of("name")); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your personal information") + .requestedSchema(requestedSchema) + .build(); + + java.util.Map responseContent = new java.util.HashMap<>(); + responseContent.put("name", "John Doe"); + responseContent.put("age", 30); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(responseContent) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); + + assertThat(result).isEqualTo(expectedResult); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content()).isNotNull(); + assertThat(result.content().get("name")).isEqualTo("John Doe"); + assertThat(result.content().get("age")).isEqualTo(30); + } + + @Test + void testCreateElicitationWithDeclineAction() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide sensitive information") + .build(); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.DECLINE) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); + + assertThat(result).isEqualTo(expectedResult); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.DECLINE); + } + + @Test + void testCreateElicitationWithCancelAction() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your information") + .build(); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.CANCEL) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); + + assertThat(result).isEqualTo(expectedResult); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.CANCEL); + } + + @Test + void testCreateElicitationWithSessionError() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Session communication error"))); + + assertThatThrownBy(() -> exchangeWithElicitation.createElicitation(elicitRequest)) + .isInstanceOf(RuntimeException.class) + .hasMessage("Session communication error"); + } + + // --------------------------------------- + // Create Message Tests + // --------------------------------------- + + @Test + void testCreateMessageWithNullCapabilities() { + + McpAsyncServerExchange asyncExchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, + clientInfo); + McpSyncServerExchange exchangeWithNullCapabilities = new McpSyncServerExchange( + asyncExchangeWithNullCapabilities); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello, world!")))) + .build(); + + assertThatThrownBy(() -> exchangeWithNullCapabilities.createMessage(createMessageRequest)) + .isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + + // Verify that sendRequest was never called due to null capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), + any(TypeRef.class)); + } + + @Test + void testCreateMessageWithoutSamplingCapabilities() { + + McpSchema.ClientCapabilities capabilitiesWithoutSampling = McpSchema.ClientCapabilities.builder() + .roots(true) + .build(); + + McpAsyncServerExchange asyncExchangeWithoutSampling = new McpAsyncServerExchange(mockSession, + capabilitiesWithoutSampling, clientInfo); + McpSyncServerExchange exchangeWithoutSampling = new McpSyncServerExchange(asyncExchangeWithoutSampling); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello, world!")))) + .build(); + + assertThatThrownBy(() -> exchangeWithoutSampling.createMessage(createMessageRequest)) + .isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + + // Verify that sendRequest was never called due to missing sampling capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), + any(TypeRef.class)); + } + + @Test + void testCreateMessageWithBasicRequest() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, + capabilitiesWithSampling, clientInfo); + McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello, world!")))) + .build(); + + McpSchema.CreateMessageResult expectedResult = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(new McpSchema.TextContent("Hello! How can I help you today?")) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); + + assertThat(result).isEqualTo(expectedResult); + assertThat(result.role()).isEqualTo(McpSchema.Role.ASSISTANT); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Hello! How can I help you today?"); + assertThat(result.model()).isEqualTo("gpt-4"); + assertThat(result.stopReason()).isEqualTo(McpSchema.CreateMessageResult.StopReason.END_TURN); + } + + @Test + void testCreateMessageWithImageContent() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, + capabilitiesWithSampling, clientInfo); + McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); + + // Create request with image content + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays.asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.ImageContent(null, "...", + "image/jpeg")))) + .build(); + + McpSchema.CreateMessageResult expectedResult = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(new McpSchema.TextContent("I can see an image. It appears to be a photograph.")) + .model("gpt-4-vision") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); + + assertThat(result).isEqualTo(expectedResult); + assertThat(result.role()).isEqualTo(McpSchema.Role.ASSISTANT); + assertThat(result.model()).isEqualTo("gpt-4-vision"); + } + + @Test + void testCreateMessageWithSessionError() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, + capabilitiesWithSampling, clientInfo); + McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello")))) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Session communication error"))); + + assertThatThrownBy(() -> exchangeWithSampling.createMessage(createMessageRequest)) + .isInstanceOf(RuntimeException.class) + .hasMessage("Session communication error"); + } + + @Test + void testCreateMessageWithIncludeContext() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, + capabilitiesWithSampling, clientInfo); + McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays.asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("What files are available?")))) + .includeContext(McpSchema.CreateMessageRequest.ContextInclusionStrategy.ALL_SERVERS) + .build(); + + McpSchema.CreateMessageResult expectedResult = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(new McpSchema.TextContent("Based on the available context, I can see several files...")) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); + + assertThat(result).isEqualTo(expectedResult); + assertThat(((McpSchema.TextContent) result.content()).text()).contains("context"); + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- + + @Test + void testPingWithSuccessfulResponse() { + + java.util.Map expectedResponse = java.util.Map.of(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResponse)); + + exchange.ping(); + + // Verify that sendRequest was called with correct parameters + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); + } + + @Test + void testPingWithMcpError() { + // Given - Mock an MCP-specific error during ping + McpError mcpError = new McpError("Server unavailable"); + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) + .thenReturn(Mono.error(mcpError)); + + // When & Then + assertThatThrownBy(() -> exchange.ping()).isInstanceOf(McpError.class).hasMessage("Server unavailable"); + + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); + } + + @Test + void testPingMultipleCalls() { + + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) + .thenReturn(Mono.just(Map.of())) + .thenReturn(Mono.just(Map.of())); + + // First call + exchange.ping(); + + // Second call + exchange.ping(); + + // Verify that sendRequest was called twice + verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateListingTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateListingTest.java new file mode 100644 index 000000000..61703c306 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateListingTest.java @@ -0,0 +1,100 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test to verify the separation of regular resources and resource templates. Regular + * resources (without template parameters) should only appear in resources/list. Template + * resources (containing {}) should only appear in resources/templates/list. + */ +public class ResourceTemplateListingTest { + + @Test + void testTemplateResourcesFilteredFromRegularListing() { + // The change we made filters resources containing "{" from the regular listing + // This test verifies that behavior is working correctly + + // Given a string with template parameter + String templateUri = "file:///test/{userId}/profile.txt"; + assertThat(templateUri.contains("{")).isTrue(); + + // And a regular URI + String regularUri = "file:///test/regular.txt"; + assertThat(regularUri.contains("{")).isFalse(); + + // The filter should exclude template URIs + assertThat(!templateUri.contains("{")).isFalse(); + assertThat(!regularUri.contains("{")).isTrue(); + } + + @Test + void testResourceListingWithMixedResources() { + // Create resource list with both regular and template resources + List allResources = List.of( + new McpSchema.Resource("file:///test/doc1.txt", "Document 1", "text/plain", null, null), + new McpSchema.Resource("file:///test/doc2.txt", "Document 2", "text/plain", null, null), + new McpSchema.Resource("file:///test/{type}/document.txt", "Typed Document", "text/plain", null, null), + new McpSchema.Resource("file:///users/{userId}/files/{fileId}", "User File", "text/plain", null, null)); + + // Apply the filter logic from McpAsyncServer line 438 + List filteredResources = allResources.stream() + .filter(resource -> !resource.uri().contains("{")) + .collect(Collectors.toList()); + + // Verify only regular resources are included + assertThat(filteredResources).hasSize(2); + assertThat(filteredResources).extracting(McpSchema.Resource::uri) + .containsExactlyInAnyOrder("file:///test/doc1.txt", "file:///test/doc2.txt"); + } + + @Test + void testResourceTemplatesListedSeparately() { + // Create mixed resources + List resources = List.of( + new McpSchema.Resource("file:///test/regular.txt", "Regular Resource", "text/plain", null, null), + new McpSchema.Resource("file:///test/user/{userId}/profile.txt", "User Profile", "text/plain", null, + null)); + + // Create explicit resource template + McpSchema.ResourceTemplate explicitTemplate = new McpSchema.ResourceTemplate( + "file:///test/document/{docId}/content.txt", "Document Template", null, "text/plain", null); + + // Filter regular resources (those without template parameters) + List regularResources = resources.stream() + .filter(resource -> !resource.uri().contains("{")) + .collect(Collectors.toList()); + + // Extract template resources (those with template parameters) + List templateResources = resources.stream() + .filter(resource -> resource.uri().contains("{")) + .map(resource -> new McpSchema.ResourceTemplate(resource.uri(), resource.name(), resource.description(), + resource.mimeType(), resource.annotations())) + .collect(Collectors.toList()); + + // Verify regular resources list + assertThat(regularResources).hasSize(1); + assertThat(regularResources.get(0).uri()).isEqualTo("file:///test/regular.txt"); + + // Verify template resources list includes both extracted and explicit templates + assertThat(templateResources).hasSize(1); + assertThat(templateResources.get(0).uriTemplate()).isEqualTo("file:///test/user/{userId}/profile.txt"); + + // In the actual implementation, both would be combined + List allTemplates = List.of(templateResources.get(0), explicitTemplate); + assertThat(allTemplates).hasSize(2); + assertThat(allTemplates).extracting(McpSchema.ResourceTemplate::uriTemplate) + .containsExactlyInAnyOrder("file:///test/user/{userId}/profile.txt", + "file:///test/document/{docId}/content.txt"); + } + +} \ No newline at end of file diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateManagementTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateManagementTests.java new file mode 100644 index 000000000..b7d46a967 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateManagementTests.java @@ -0,0 +1,299 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; + +import io.modelcontextprotocol.MockMcpServerTransport; +import io.modelcontextprotocol.MockMcpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Test suite for Resource Template Management functionality. Tests the new + * addResourceTemplate() and removeResourceTemplate() methods, as well as the Map-based + * resource template storage. + * + * @author Christian Tzolov + */ +public class ResourceTemplateManagementTests { + + private static final String TEST_TEMPLATE_URI = "test://resource/{param}"; + + private static final String TEST_TEMPLATE_NAME = "test-template"; + + private MockMcpServerTransportProvider mockTransportProvider; + + private McpAsyncServer mcpAsyncServer; + + @BeforeEach + void setUp() { + mockTransportProvider = new MockMcpServerTransportProvider(new MockMcpServerTransport()); + } + + @AfterEach + void tearDown() { + if (mcpAsyncServer != null) { + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + } + + // --------------------------------------- + // Async Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResourceTemplate(specification)).verifyComplete(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .build(); + + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResourceTemplate(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + + assertThatCode(() -> serverWithoutResources.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplate() { + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate(TEST_TEMPLATE_URI)).verifyComplete(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutResources.removeResourceTemplate(TEST_TEMPLATE_URI)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + + assertThatCode(() -> serverWithoutResources.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource template should complete successfully (no + // error) + // as per the new implementation that just logs a warning + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("nonexistent://template/{id}")).verifyComplete(); + } + + @Test + void testReplaceExistingResourceTemplate() { + ResourceTemplate originalTemplate = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Original template") + .mimeType("text/plain") + .build(); + + ResourceTemplate updatedTemplate = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Updated template") + .mimeType("application/json") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification originalSpec = new McpServerFeatures.AsyncResourceTemplateSpecification( + originalTemplate, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + McpServerFeatures.AsyncResourceTemplateSpecification updatedSpec = new McpServerFeatures.AsyncResourceTemplateSpecification( + updatedTemplate, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(originalSpec) + .build(); + + // Adding a resource template with the same URI should replace the existing one + StepVerifier.create(mcpAsyncServer.addResourceTemplate(updatedSpec)).verifyComplete(); + } + + // --------------------------------------- + // Sync Resource Template Tests + // --------------------------------------- + + @Test + void testSyncAddResourceTemplate() { + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = McpServer.sync(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.addResourceTemplate(specification)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testSyncRemoveResourceTemplate() { + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = McpServer.sync(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate(TEST_TEMPLATE_URI)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Map-based Storage Tests + // --------------------------------------- + + @Test + void testResourceTemplateMapBasedStorage() { + ResourceTemplate template1 = ResourceTemplate.builder() + .uriTemplate("test://template1/{id}") + .name("template1") + .description("First template") + .mimeType("text/plain") + .build(); + + ResourceTemplate template2 = ResourceTemplate.builder() + .uriTemplate("test://template2/{id}") + .name("template2") + .description("Second template") + .mimeType("application/json") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification spec1 = new McpServerFeatures.AsyncResourceTemplateSpecification( + template1, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + McpServerFeatures.AsyncResourceTemplateSpecification spec2 = new McpServerFeatures.AsyncResourceTemplateSpecification( + template2, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(spec1, spec2) + .build(); + + // Verify both templates are stored (this would be tested through integration + // tests + // or by accessing internal state, but for unit tests we verify no exceptions) + assertThat(mcpAsyncServer).isNotNull(); + } + + @Test + void testResourceTemplateBuilderWithMap() { + // Test that the new Map-based builder methods work correctly + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + // Test varargs builder method + assertThatCode(() -> { + McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build() + .closeGracefully() + .block(Duration.ofSeconds(10)); + }).doesNotThrowAnyException(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java similarity index 52% rename from mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java index 715f636de..8906adfe0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java @@ -4,22 +4,25 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; /** - * Tests for {@link McpAsyncServer} using {@link HttpServletSseServerTransport}. + * Tests for {@link McpAsyncServer} using {@link HttpServletSseServerTransportProvider}. * * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout class ServletSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + protected McpServerTransportProvider createMcpTransportProvider() { + return HttpServletSseServerTransportProvider.builder().messageEndpoint("/mcp/message").build(); + } + @Override - protected ServerMcpTransport createMcpTransport() { - return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java similarity index 52% rename from mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java index 208de7f74..7b77f9241 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java @@ -4,22 +4,25 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; /** - * Tests for {@link McpSyncServer} using {@link HttpServletSseServerTransport}. + * Tests for {@link McpSyncServer} using {@link HttpServletSseServerTransportProvider}. * * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout class ServletSseMcpSyncServerTests extends AbstractMcpSyncServerTests { + protected McpServerTransportProvider createMcpTransportProvider() { + return HttpServletSseServerTransportProvider.builder().messageEndpoint("/mcp/message").build(); + } + @Override - protected ServerMcpTransport createMcpTransport() { - return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java similarity index 53% rename from mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java index e933d6382..b2dfbea25 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java @@ -4,10 +4,12 @@ package io.modelcontextprotocol.server; -import io.modelcontextprotocol.server.transport.StdioServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + /** * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. * @@ -16,9 +18,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class StdioMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + protected McpServerTransportProvider createMcpTransportProvider() { + return new StdioServerTransportProvider(JSON_MAPPER); + } + @Override - protected ServerMcpTransport createMcpTransport() { - return new StdioServerTransport(); + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java similarity index 51% rename from mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java index d9350417f..c97c75d38 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java @@ -4,21 +4,27 @@ package io.modelcontextprotocol.server; -import io.modelcontextprotocol.server.transport.StdioServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + /** - * Tests for {@link McpSyncServer} using {@link StdioServerTransport}. + * Tests for {@link McpSyncServer} using {@link StdioServerTransportProvider}. * * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout class StdioMcpSyncServerTests extends AbstractMcpSyncServerTests { + protected McpServerTransportProvider createMcpTransportProvider() { + return new StdioServerTransportProvider(JSON_MAPPER); + } + @Override - protected ServerMcpTransport createMcpTransport() { - return new StdioServerTransport(); + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java new file mode 100644 index 000000000..9bcd2bc84 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java @@ -0,0 +1,105 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; + +/** + * Tests for {@link McpServerFeatures.SyncToolSpecification.Builder}. + * + * @author Christian Tzolov + */ +class SyncToolSpecificationBuilderTest { + + @Test + void builderShouldCreateValidSyncToolSpecification() { + + Tool tool = Tool.builder().name("test-tool").title("A test tool").inputSchema(EMPTY_JSON_SCHEMA).build(); + + McpServerFeatures.SyncToolSpecification specification = McpServerFeatures.SyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> CallToolResult.builder() + .content(List.of(new TextContent("Test result"))) + .isError(false) + .build()) + .build(); + + assertThat(specification).isNotNull(); + assertThat(specification.tool()).isEqualTo(tool); + assertThat(specification.callHandler()).isNotNull(); + assertThat(specification.call()).isNull(); // deprecated field should be null + } + + @Test + void builderShouldThrowExceptionWhenToolIsNull() { + assertThatThrownBy(() -> McpServerFeatures.SyncToolSpecification.builder() + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Tool must not be null"); + } + + @Test + void builderShouldThrowExceptionWhenCallToolIsNull() { + Tool tool = Tool.builder().name("test-tool").description("A test tool").inputSchema(EMPTY_JSON_SCHEMA).build(); + + assertThatThrownBy(() -> McpServerFeatures.SyncToolSpecification.builder().tool(tool).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("CallTool function must not be null"); + } + + @Test + void builderShouldAllowMethodChaining() { + Tool tool = Tool.builder().name("test-tool").description("A test tool").inputSchema(EMPTY_JSON_SCHEMA).build(); + McpServerFeatures.SyncToolSpecification.Builder builder = McpServerFeatures.SyncToolSpecification.builder(); + + // Then - verify method chaining returns the same builder instance + assertThat(builder.tool(tool)).isSameAs(builder); + assertThat(builder + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build())) + .isSameAs(builder); + } + + @Test + void builtSpecificationShouldExecuteCallToolCorrectly() { + Tool tool = Tool.builder() + .name("calculator") + .description("Simple calculator") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "42"; + + McpServerFeatures.SyncToolSpecification specification = McpServerFeatures.SyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> { + // Simple test implementation + return CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build(); + }) + .build(); + + CallToolRequest request = new CallToolRequest("calculator", Map.of()); + CallToolResult result = specification.callHandler().apply(null, request); + + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java new file mode 100644 index 000000000..be88097b3 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class HttpServletSseServerCustomContextPathTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String CUSTOM_CONTEXT_PATH = "/api/v1"; + + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + + // Create and configure the transport provider + mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() + .baseUrl(CUSTOM_CONTEXT_PATH) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, mcpServerTransportProvider); + + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_CONTEXT_PATH + CUSTOM_SSE_ENDPOINT) + .build()); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void testCustomContextPath() { + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); + try (//@formatter:off + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) .build()) { //@formatter:on + + assertThat(client.initialize()).isNotNull(); + } + server.close(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/McpTestRequestRecordingServletFilter.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/McpTestRequestRecordingServletFilter.java new file mode 100644 index 000000000..b94552d12 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/McpTestRequestRecordingServletFilter.java @@ -0,0 +1,128 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ReadListener; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; + +/** + * Simple {@link Filter} which records calls made to an MCP server. + * + * @author Daniel Garnier-Moiroux + */ +public class McpTestRequestRecordingServletFilter implements Filter { + + private final List calls = new ArrayList<>(); + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) + throws IOException, ServletException { + + if (servletRequest instanceof HttpServletRequest req) { + var headers = Collections.list(req.getHeaderNames()) + .stream() + .collect(Collectors.toUnmodifiableMap(Function.identity(), + name -> String.join(",", Collections.list(req.getHeaders(name))))); + var request = new CachedBodyHttpServletRequest(req); + calls.add(new Call(req.getMethod(), headers, request.getBodyAsString())); + filterChain.doFilter(request, servletResponse); + } + else { + filterChain.doFilter(servletRequest, servletResponse); + } + + } + + public List getCalls() { + + return List.copyOf(calls); + } + + public record Call(String method, Map headers, String body) { + + } + + public static class CachedBodyHttpServletRequest extends HttpServletRequestWrapper { + + private final byte[] cachedBody; + + public CachedBodyHttpServletRequest(HttpServletRequest request) throws IOException { + super(request); + this.cachedBody = request.getInputStream().readAllBytes(); + } + + @Override + public ServletInputStream getInputStream() { + return new CachedBodyServletInputStream(cachedBody); + } + + @Override + public BufferedReader getReader() { + return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8)); + } + + public String getBodyAsString() { + return new String(cachedBody, StandardCharsets.UTF_8); + } + + } + + public static class CachedBodyServletInputStream extends ServletInputStream { + + private InputStream cachedBodyInputStream; + + public CachedBodyServletInputStream(byte[] cachedBody) { + this.cachedBodyInputStream = new ByteArrayInputStream(cachedBody); + } + + @Override + public boolean isFinished() { + try { + return cachedBodyInputStream.available() == 0; + } + catch (IOException e) { + e.printStackTrace(); + } + return false; + } + + @Override + public boolean isReady() { + return true; + } + + @Override + public void setReadListener(ReadListener readListener) { + throw new UnsupportedOperationException(); + } + + @Override + public int read() throws IOException { + return cachedBodyInputStream.read(); + } + + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java new file mode 100644 index 000000000..6a70af33d --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -0,0 +1,223 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link StdioServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Disabled +class StdioServerTransportProviderTests { + + private final PrintStream originalOut = System.out; + + private final PrintStream originalErr = System.err; + + private ByteArrayOutputStream testErr; + + private PrintStream testOutPrintStream; + + private StdioServerTransportProvider transportProvider; + + private McpServerSession.Factory sessionFactory; + + private McpServerSession mockSession; + + @BeforeEach + void setUp() { + testErr = new ByteArrayOutputStream(); + + testOutPrintStream = new PrintStream(testErr, true); + System.setOut(testOutPrintStream); + System.setErr(testOutPrintStream); + + // Create mocks for session factory and session + mockSession = mock(McpServerSession.class); + sessionFactory = mock(McpServerSession.Factory.class); + + // Configure mock behavior + when(sessionFactory.create(any(McpServerTransport.class))).thenReturn(mockSession); + when(mockSession.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); + + transportProvider = new StdioServerTransportProvider(JSON_MAPPER, System.in, testOutPrintStream); + } + + @AfterEach + void tearDown() { + if (transportProvider != null) { + transportProvider.closeGracefully().block(); + } + if (testOutPrintStream != null) { + testOutPrintStream.close(); + } + System.setOut(originalOut); + System.setErr(originalErr); + } + + @Test + void shouldCreateSessionWhenSessionFactoryIsSet() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Verify session was created with a transport + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldHandleIncomingMessages() throws Exception { + + String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{},\"id\":1}\n"; + InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); + + transportProvider = new StdioServerTransportProvider(JSON_MAPPER, stream, System.out); + // Set up a real session to capture the message + AtomicReference capturedMessage = new AtomicReference<>(); + CountDownLatch messageLatch = new CountDownLatch(1); + + McpServerSession.Factory realSessionFactory = transport -> { + McpServerSession session = mock(McpServerSession.class); + when(session.handle(any())).thenAnswer(invocation -> { + capturedMessage.set(invocation.getArgument(0)); + messageLatch.countDown(); + return Mono.empty(); + }); + when(session.closeGracefully()).thenReturn(Mono.empty()); + return session; + }; + + // Set session factory + transportProvider.setSessionFactory(realSessionFactory); + + // Wait for the message to be processed using the latch + StepVerifier.create(Mono.fromCallable(() -> messageLatch.await(100, TimeUnit.SECONDS)).flatMap(success -> { + if (!success) { + return Mono.error(new AssertionError("Timeout waiting for message processing")); + } + return Mono.just(capturedMessage.get()); + })).assertNext(message -> { + assertThat(message).isNotNull(); + assertThat(message).isInstanceOf(McpSchema.JSONRPCRequest.class); + McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) message; + assertThat(request.method()).isEqualTo("test"); + assertThat(request.id()).isEqualTo(1); + }).verifyComplete(); + } + + @Test + void shouldNotifyClients() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Send notification + String method = "testNotification"; + Map params = Map.of("key", "value"); + + StepVerifier.create(transportProvider.notifyClients(method, params)).verifyComplete(); + + // Error log should be empty + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldCloseGracefully() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Close gracefully + StepVerifier.create(transportProvider.closeGracefully()).verifyComplete(); + + // Error log should be empty + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldHandleMultipleCloseGracefullyCalls() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Close gracefully multiple times + StepVerifier + .create(transportProvider.closeGracefully() + .then(transportProvider.closeGracefully()) + .then(transportProvider.closeGracefully())) + .verifyComplete(); + + // Error log should be empty + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldHandleNotificationBeforeSessionFactoryIsSet() { + + transportProvider = new StdioServerTransportProvider(JSON_MAPPER); + // Send notification before setting session factory + StepVerifier.create(transportProvider.notifyClients("testNotification", Map.of("key", "value"))) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class); + }); + } + + @Test + void shouldHandleInvalidJsonMessage() throws Exception { + + // Write an invalid JSON message to the input stream + String jsonMessage = "{invalid json}\n"; + InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); + + transportProvider = new StdioServerTransportProvider(JSON_MAPPER, stream, testOutPrintStream); + + // Set up a session factory + transportProvider.setSessionFactory(sessionFactory); + + // Use StepVerifier with a timeout to wait for the error to be processed + StepVerifier + .create(Mono.delay(java.time.Duration.ofMillis(500)).then(Mono.fromCallable(() -> testErr.toString()))) + .assertNext(errorOutput -> assertThat(errorOutput).contains("Error processing inbound message")) + .verifyComplete(); + } + + @Test + void shouldHandleSessionClose() throws Exception { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Close the transport provider + transportProvider.close(); + + // Verify session was closed + verify(mockSession).closeGracefully(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java new file mode 100644 index 000000000..490e29838 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -0,0 +1,81 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; + +import jakarta.servlet.Filter; +import jakarta.servlet.Servlet; +import org.apache.catalina.Context; +import org.apache.catalina.startup.Tomcat; +import org.apache.tomcat.util.descriptor.web.FilterDef; +import org.apache.tomcat.util.descriptor.web.FilterMap; + +/** + * @author Christian Tzolov + * @author Daniel Garnier-Moiroux + */ +public class TomcatTestUtil { + + TomcatTestUtil() { + // Prevent instantiation + } + + public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet, + Filter... additionalFilters) { + + var tomcat = new Tomcat(); + tomcat.setPort(port); + + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + Context context = tomcat.addContext(contextPath, baseDir); + + // Add transport servlet to Tomcat + org.apache.catalina.Wrapper wrapper = context.createWrapper(); + wrapper.setName("mcpServlet"); + wrapper.setServlet(servlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addChild(wrapper); + context.addServletMappingDecoded("/*", "mcpServlet"); + + for (var filter : additionalFilters) { + var filterDef = new FilterDef(); + filterDef.setFilter(filter); + filterDef.setFilterName(McpTestRequestRecordingServletFilter.class.getSimpleName()); + context.addFilterDef(filterDef); + + var filterMap = new FilterMap(); + filterMap.setFilterName(McpTestRequestRecordingServletFilter.class.getSimpleName()); + filterMap.addURLPattern("/*"); + context.addFilterMap(filterMap); + } + + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); + + return tomcat; + } + + /** + * Finds an available port on the local machine. + * @return an available port number + * @throws IllegalStateException if no available port can be found + */ + public static int findAvailablePort() { + try (final ServerSocket socket = new ServerSocket()) { + socket.bind(new InetSocketAddress(0)); + return socket.getLocalPort(); + } + catch (final IOException e) { + throw new IllegalStateException("Cannot bind to an available port!", e); + } + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java new file mode 100644 index 000000000..a0bd568ef --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java @@ -0,0 +1,9 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +public class ArgumentException { + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/CompleteCompletionSerializationTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/CompleteCompletionSerializationTest.java new file mode 100644 index 000000000..55f71fea4 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/CompleteCompletionSerializationTest.java @@ -0,0 +1,28 @@ +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.json.McpJsonMapper; +import org.junit.jupiter.api.Test; +import java.io.IOException; +import java.util.Collections; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class CompleteCompletionSerializationTest { + + @Test + void codeCompletionSerialization() throws IOException { + McpJsonMapper jsonMapper = McpJsonMapper.getDefault(); + McpSchema.CompleteResult.CompleteCompletion codeComplete = new McpSchema.CompleteResult.CompleteCompletion( + Collections.emptyList(), 0, false); + String json = jsonMapper.writeValueAsString(codeComplete); + String expected = """ + {"values":[],"total":0,"hasMore":false}"""; + assertEquals(expected, json, json); + + McpSchema.CompleteResult completeResult = new McpSchema.CompleteResult(codeComplete); + json = jsonMapper.writeValueAsString(completeResult); + expected = """ + {"completion":{"values":[],"total":0,"hasMore":false}}"""; + assertEquals(expected, json, json); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java new file mode 100644 index 000000000..fbe17d464 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests for MCP-specific validation of JSONRPCRequest ID requirements. + * + * @author Christian Tzolov + */ +public class JSONRPCRequestMcpValidationTest { + + @Test + public void testValidStringId() { + assertDoesNotThrow(() -> { + var request = new McpSchema.JSONRPCRequest("2.0", "test/method", "string-id", null); + assertEquals("string-id", request.id()); + }); + } + + @Test + public void testValidIntegerId() { + assertDoesNotThrow(() -> { + var request = new McpSchema.JSONRPCRequest("2.0", "test/method", 123, null); + assertEquals(123, request.id()); + }); + } + + @Test + public void testValidLongId() { + assertDoesNotThrow(() -> { + var request = new McpSchema.JSONRPCRequest("2.0", "test/method", 123L, null); + assertEquals(123L, request.id()); + }); + } + + @Test + public void testNullIdThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", null, null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST include an ID")); + assertTrue(exception.getMessage().contains("null IDs are not allowed")); + } + + @Test + public void testDoubleIdTypeThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", 123.45, null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST have an ID that is either a string or integer")); + } + + @Test + public void testBooleanIdThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", true, null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST have an ID that is either a string or integer")); + } + + @Test + public void testArrayIdThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", new String[] { "array" }, null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST have an ID that is either a string or integer")); + } + + @Test + public void testObjectIdThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", new Object(), null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST have an ID that is either a string or integer")); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java new file mode 100644 index 000000000..3de06f503 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -0,0 +1,313 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.function.Function; + +import io.modelcontextprotocol.MockMcpClientTransport; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test suite for {@link McpClientSession} that verifies its JSON-RPC message handling, + * request-response correlation, and notification processing. + * + * @author Christian Tzolov + */ +class McpClientSessionTests { + + private static final Logger logger = LoggerFactory.getLogger(McpClientSessionTests.class); + + private static final Duration TIMEOUT = Duration.ofSeconds(5); + + private static final String TEST_METHOD = "test.method"; + + private static final String TEST_NOTIFICATION = "test.notification"; + + private static final String ECHO_METHOD = "echo"; + + TypeRef responseType = new TypeRef<>() { + }; + + @Test + void testSendRequest() { + String testParam = "test parameter"; + String responseData = "test response"; + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + // Create a Mono that will emit the response after the request is sent + Mono responseMono = session.sendRequest(TEST_METHOD, testParam, responseType); + // Verify response handling + StepVerifier.create(responseMono).then(() -> { + McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); + transport.simulateIncomingMessage( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), responseData, null)); + }).consumeNextWith(response -> { + // Verify the request was sent + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessageAsRequest(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCRequest.class); + McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) sentMessage; + assertThat(request.method()).isEqualTo(TEST_METHOD); + assertThat(request.params()).isEqualTo(testParam); + assertThat(response).isEqualTo(responseData); + }).verifyComplete(); + + session.close(); + } + + @Test + void testSendRequestWithError() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); + + // Verify error handling + StepVerifier.create(responseMono).then(() -> { + McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); + // Simulate error response + McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.METHOD_NOT_FOUND, "Method not found", null); + transport.simulateIncomingMessage( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); + }).expectError(McpError.class).verify(); + + session.close(); + } + + @Test + void testRequestTimeout() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); + + // Verify timeout + StepVerifier.create(responseMono) + .expectError(java.util.concurrent.TimeoutException.class) + .verify(TIMEOUT.plusSeconds(1)); + + session.close(); + } + + @Test + void testSendNotification() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + Map params = Map.of("key", "value"); + Mono notificationMono = session.sendNotification(TEST_NOTIFICATION, params); + + // Verify notification was sent + StepVerifier.create(notificationMono).consumeSubscriptionWith(response -> { + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCNotification.class); + McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) sentMessage; + assertThat(notification.method()).isEqualTo(TEST_NOTIFICATION); + assertThat(notification.params()).isEqualTo(params); + }).verifyComplete(); + + session.close(); + } + + @Test + void testRequestHandling() { + String echoMessage = "Hello MCP!"; + Map> requestHandlers = Map.of(ECHO_METHOD, + params -> Mono.just(params)); + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of(), Function.identity()); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, + "test-id", echoMessage); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.result()).isEqualTo(echoMessage); + assertThat(response.error()).isNull(); + + session.close(); + } + + @Test + void testNotificationHandling() { + Sinks.One receivedParams = Sinks.one(); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params))), + Function.identity()); + + // Simulate incoming notification from the server + Map notificationParams = Map.of("status", "ready"); + + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + TEST_NOTIFICATION, notificationParams); + + transport.simulateIncomingMessage(notification); + + // Verify handler was called + assertThat(receivedParams.asMono().block(Duration.ofSeconds(1))).isEqualTo(notificationParams); + + session.close(); + } + + @Test + void testUnknownMethodHandling() { + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + // Simulate incoming request for unknown method + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "unknown.method", + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify error response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND); + + session.close(); + } + + @Test + void testRequestHandlerThrowsMcpErrorWithJsonRpcError() { + // Setup: Create a request handler that throws McpError with custom error code and + // data + String testMethod = "test.customError"; + Map errorData = Map.of("customField", "customValue"); + McpClientSession.RequestHandler failingHandler = params -> Mono + .error(McpError.builder(123).message("Custom error message").data(errorData).build()); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain the custom error from McpError + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(123); + assertThat(response.error().message()).isEqualTo("Custom error message"); + assertThat(response.error().data()).isEqualTo(errorData); + + session.close(); + } + + @Test + void testRequestHandlerThrowsGenericException() { + // Setup: Create a request handler that throws a generic RuntimeException + String testMethod = "test.genericError"; + RuntimeException exception = new RuntimeException("Something went wrong"); + McpClientSession.RequestHandler failingHandler = params -> Mono.error(exception); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain INTERNAL_ERROR with aggregated exception + // messages in data field + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(response.error().message()).isEqualTo("Something went wrong"); + // Verify data field contains aggregated exception messages + assertThat(response.error().data()).isNotNull(); + assertThat(response.error().data().toString()).contains("RuntimeException"); + assertThat(response.error().data().toString()).contains("Something went wrong"); + + session.close(); + } + + @Test + void testRequestHandlerThrowsExceptionWithCause() { + // Setup: Create a request handler that throws an exception with a cause chain + String testMethod = "test.chainedError"; + RuntimeException rootCause = new IllegalArgumentException("Root cause message"); + RuntimeException middleCause = new IllegalStateException("Middle cause message", rootCause); + RuntimeException topException = new RuntimeException("Top level message", middleCause); + McpClientSession.RequestHandler failingHandler = params -> Mono.error(topException); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain INTERNAL_ERROR with full exception chain + // in data field + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(response.error().message()).isEqualTo("Top level message"); + // Verify data field contains the full exception chain + String dataString = response.error().data().toString(); + assertThat(dataString).contains("RuntimeException"); + assertThat(dataString).contains("Top level message"); + assertThat(dataString).contains("IllegalStateException"); + assertThat(dataString).contains("Middle cause message"); + assertThat(dataString).contains("IllegalArgumentException"); + assertThat(dataString).contains("Root cause message"); + + session.close(); + } + + @Test + void testGracefulShutdown() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + StepVerifier.create(session.closeGracefully()).verifyComplete(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpErrorTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpErrorTest.java new file mode 100644 index 000000000..0978ffe0b --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpErrorTest.java @@ -0,0 +1,22 @@ +package io.modelcontextprotocol.spec; + +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class McpErrorTest { + + @Test + void testNotFound() { + String uri = "file:///nonexistent.txt"; + McpError mcpError = McpError.RESOURCE_NOT_FOUND.apply(uri); + assertNotNull(mcpError.getJsonRpcError()); + assertEquals(-32002, mcpError.getJsonRpcError().code()); + assertEquals("Resource not found", mcpError.getJsonRpcError().message()); + assertEquals(Map.of("uri", uri), mcpError.getJsonRpcError().data()); + } + +} \ No newline at end of file diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java new file mode 100644 index 000000000..6b0004cb9 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -0,0 +1,1765 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ + +package io.modelcontextprotocol.spec; + +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; + +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import net.javacrumbs.jsonunit.core.Option; + +/** + * @author Christian Tzolov + * @author Anurag Pant + */ +public class McpSchemaTests { + + // Content Types Tests + + @Test + void testTextContent() throws Exception { + McpSchema.TextContent test = new McpSchema.TextContent("XXX"); + String value = JSON_MAPPER.writeValueAsString(test); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"type":"text","text":"XXX"}""")); + } + + @Test + void testTextContentDeserialization() throws Exception { + McpSchema.TextContent textContent = JSON_MAPPER.readValue(""" + {"type":"text","text":"XXX","_meta":{"metaKey":"metaValue"}}""", McpSchema.TextContent.class); + + assertThat(textContent).isNotNull(); + assertThat(textContent.type()).isEqualTo("text"); + assertThat(textContent.text()).isEqualTo("XXX"); + assertThat(textContent.meta()).containsKey("metaKey"); + } + + @Test + void testContentDeserializationWrongType() throws Exception { + + assertThatThrownBy(() -> JSON_MAPPER.readValue(""" + {"type":"WRONG","text":"XXX"}""", McpSchema.TextContent.class)) + .isInstanceOf(InvalidTypeIdException.class) + .hasMessageContaining( + "Could not resolve type id 'WRONG' as a subtype of `io.modelcontextprotocol.spec.McpSchema$TextContent`: known type ids = [audio, image, resource, resource_link, text]"); + } + + @Test + void testImageContent() throws Exception { + McpSchema.ImageContent test = new McpSchema.ImageContent(null, null, "base64encodeddata", "image/png"); + String value = JSON_MAPPER.writeValueAsString(test); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"type":"image","data":"base64encodeddata","mimeType":"image/png"}""")); + } + + @Test + void testImageContentDeserialization() throws Exception { + McpSchema.ImageContent imageContent = JSON_MAPPER.readValue(""" + {"type":"image","data":"base64encodeddata","mimeType":"image/png","_meta":{"metaKey":"metaValue"}}""", + McpSchema.ImageContent.class); + assertThat(imageContent).isNotNull(); + assertThat(imageContent.type()).isEqualTo("image"); + assertThat(imageContent.data()).isEqualTo("base64encodeddata"); + assertThat(imageContent.mimeType()).isEqualTo("image/png"); + assertThat(imageContent.meta()).containsKey("metaKey"); + } + + @Test + void testAudioContent() throws Exception { + McpSchema.AudioContent audioContent = new McpSchema.AudioContent(null, "base64encodeddata", "audio/wav"); + String value = JSON_MAPPER.writeValueAsString(audioContent); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"type":"audio","data":"base64encodeddata","mimeType":"audio/wav"}""")); + } + + @Test + void testAudioContentDeserialization() throws Exception { + McpSchema.AudioContent audioContent = JSON_MAPPER.readValue(""" + {"type":"audio","data":"base64encodeddata","mimeType":"audio/wav","_meta":{"metaKey":"metaValue"}}""", + McpSchema.AudioContent.class); + assertThat(audioContent).isNotNull(); + assertThat(audioContent.type()).isEqualTo("audio"); + assertThat(audioContent.data()).isEqualTo("base64encodeddata"); + assertThat(audioContent.mimeType()).isEqualTo("audio/wav"); + assertThat(audioContent.meta()).containsKey("metaKey"); + } + + @Test + void testCreateMessageRequestWithMeta() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("User message"); + McpSchema.SamplingMessage message = new McpSchema.SamplingMessage(McpSchema.Role.USER, content); + McpSchema.ModelHint hint = new McpSchema.ModelHint("gpt-4"); + McpSchema.ModelPreferences preferences = new McpSchema.ModelPreferences(Collections.singletonList(hint), 0.3, + 0.7, 0.9); + + Map metadata = new HashMap<>(); + metadata.put("session", "test-session"); + + Map meta = new HashMap<>(); + meta.put("progressToken", "create-message-token-456"); + + McpSchema.CreateMessageRequest request = McpSchema.CreateMessageRequest.builder() + .messages(Collections.singletonList(message)) + .modelPreferences(preferences) + .systemPrompt("You are a helpful assistant") + .includeContext(McpSchema.CreateMessageRequest.ContextInclusionStrategy.THIS_SERVER) + .temperature(0.7) + .maxTokens(1000) + .stopSequences(Arrays.asList("STOP", "END")) + .metadata(metadata) + .meta(meta) + .build(); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .containsEntry("_meta", Map.of("progressToken", "create-message-token-456")); + + // Test Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("create-message-token-456"); + } + + @Test + void testEmbeddedResource() throws Exception { + McpSchema.TextResourceContents resourceContents = new McpSchema.TextResourceContents("resource://test", + "text/plain", "Sample resource content"); + + McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); + + String value = JSON_MAPPER.writeValueAsString(test); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"type":"resource","resource":{"uri":"resource://test","mimeType":"text/plain","text":"Sample resource content"}}""")); + } + + @Test + void testEmbeddedResourceDeserialization() throws Exception { + McpSchema.EmbeddedResource embeddedResource = JSON_MAPPER.readValue( + """ + {"type":"resource","resource":{"uri":"resource://test","mimeType":"text/plain","text":"Sample resource content"},"_meta":{"metaKey":"metaValue"}}""", + McpSchema.EmbeddedResource.class); + assertThat(embeddedResource).isNotNull(); + assertThat(embeddedResource.type()).isEqualTo("resource"); + assertThat(embeddedResource.resource()).isNotNull(); + assertThat(embeddedResource.resource().uri()).isEqualTo("resource://test"); + assertThat(embeddedResource.resource().mimeType()).isEqualTo("text/plain"); + assertThat(((TextResourceContents) embeddedResource.resource()).text()).isEqualTo("Sample resource content"); + assertThat(embeddedResource.meta()).containsKey("metaKey"); + } + + @Test + void testEmbeddedResourceWithBlobContents() throws Exception { + McpSchema.BlobResourceContents resourceContents = new McpSchema.BlobResourceContents("resource://test", + "application/octet-stream", "base64encodedblob"); + + McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); + + String value = JSON_MAPPER.writeValueAsString(test); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"type":"resource","resource":{"uri":"resource://test","mimeType":"application/octet-stream","blob":"base64encodedblob"}}""")); + } + + @Test + void testEmbeddedResourceWithBlobContentsDeserialization() throws Exception { + McpSchema.EmbeddedResource embeddedResource = JSON_MAPPER.readValue( + """ + {"type":"resource","resource":{"uri":"resource://test","mimeType":"application/octet-stream","blob":"base64encodedblob","_meta":{"metaKey":"metaValue"}}}""", + McpSchema.EmbeddedResource.class); + assertThat(embeddedResource).isNotNull(); + assertThat(embeddedResource.type()).isEqualTo("resource"); + assertThat(embeddedResource.resource()).isNotNull(); + assertThat(embeddedResource.resource().uri()).isEqualTo("resource://test"); + assertThat(embeddedResource.resource().mimeType()).isEqualTo("application/octet-stream"); + assertThat(((McpSchema.BlobResourceContents) embeddedResource.resource()).blob()) + .isEqualTo("base64encodedblob"); + assertThat(((McpSchema.BlobResourceContents) embeddedResource.resource()).meta()).containsKey("metaKey"); + } + + @Test + void testResourceLink() throws Exception { + McpSchema.ResourceLink resourceLink = new McpSchema.ResourceLink("main.rs", "Main file", + "file:///project/src/main.rs", "Primary application entry point", "text/x-rust", null, null, + Map.of("metaKey", "metaValue")); + String value = JSON_MAPPER.writeValueAsString(resourceLink); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"type":"resource_link","name":"main.rs","title":"Main file","uri":"file:///project/src/main.rs","description":"Primary application entry point","mimeType":"text/x-rust","_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testResourceLinkDeserialization() throws Exception { + McpSchema.ResourceLink resourceLink = JSON_MAPPER.readValue( + """ + {"type":"resource_link","name":"main.rs","uri":"file:///project/src/main.rs","description":"Primary application entry point","mimeType":"text/x-rust","_meta":{"metaKey":"metaValue"}}""", + McpSchema.ResourceLink.class); + assertThat(resourceLink).isNotNull(); + assertThat(resourceLink.type()).isEqualTo("resource_link"); + assertThat(resourceLink.name()).isEqualTo("main.rs"); + assertThat(resourceLink.uri()).isEqualTo("file:///project/src/main.rs"); + assertThat(resourceLink.description()).isEqualTo("Primary application entry point"); + assertThat(resourceLink.mimeType()).isEqualTo("text/x-rust"); + assertThat(resourceLink.meta()).containsEntry("metaKey", "metaValue"); + } + + // JSON-RPC Message Types Tests + + @Test + void testJSONRPCRequest() throws Exception { + Map params = new HashMap<>(); + params.put("key", "value"); + + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", 1, + params); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"jsonrpc":"2.0","method":"method_name","id":1,"params":{"key":"value"}}""")); + } + + @Test + void testJSONRPCNotification() throws Exception { + Map params = new HashMap<>(); + params.put("key", "value"); + + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + "notification_method", params); + + String value = JSON_MAPPER.writeValueAsString(notification); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"jsonrpc":"2.0","method":"notification_method","params":{"key":"value"}}""")); + } + + @Test + void testJSONRPCResponse() throws Exception { + Map result = new HashMap<>(); + result.put("result_key", "result_value"); + + McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, result, null); + + String value = JSON_MAPPER.writeValueAsString(response); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"jsonrpc":"2.0","id":1,"result":{"result_key":"result_value"}}""")); + } + + @Test + void testJSONRPCResponseWithError() throws Exception { + McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INVALID_REQUEST, "Invalid request", null); + + McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, null, error); + + String value = JSON_MAPPER.writeValueAsString(response); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid request"}}""")); + } + + // Initialization Tests + + @Test + void testInitializeRequest() throws Exception { + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() + .roots(true) + .sampling() + .build(); + + McpSchema.Implementation clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); + Map meta = Map.of("metaKey", "metaValue"); + + McpSchema.InitializeRequest request = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2024_11_05, + capabilities, clientInfo, meta); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"protocolVersion":"2024-11-05","capabilities":{"roots":{"listChanged":true},"sampling":{}},"clientInfo":{"name":"test-client","version":"1.0.0"},"_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testInitializeResult() throws Exception { + McpSchema.ServerCapabilities capabilities = McpSchema.ServerCapabilities.builder() + .logging() + .prompts(true) + .resources(true, true) + .tools(true) + .build(); + + McpSchema.Implementation serverInfo = new McpSchema.Implementation("test-server", "1.0.0"); + + McpSchema.InitializeResult result = new McpSchema.InitializeResult(ProtocolVersions.MCP_2024_11_05, + capabilities, serverInfo, "Server initialized successfully"); + + String value = JSON_MAPPER.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"protocolVersion":"2024-11-05","capabilities":{"logging":{},"prompts":{"listChanged":true},"resources":{"subscribe":true,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"test-server","version":"1.0.0"},"instructions":"Server initialized successfully"}""")); + } + + // Resource Tests + + @Test + void testResource() throws Exception { + McpSchema.Annotations annotations = new McpSchema.Annotations( + Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); + + McpSchema.Resource resource = new McpSchema.Resource("resource://test", "Test Resource", "A test resource", + "text/plain", annotations); + + String value = JSON_MAPPER.writeValueAsString(resource); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"uri":"resource://test","name":"Test Resource","description":"A test resource","mimeType":"text/plain","annotations":{"audience":["user","assistant"],"priority":0.8}}""")); + } + + @Test + void testResourceBuilder() throws Exception { + McpSchema.Annotations annotations = new McpSchema.Annotations( + Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); + + McpSchema.Resource resource = McpSchema.Resource.builder() + .uri("resource://test") + .name("Test Resource") + .description("A test resource") + .mimeType("text/plain") + .size(256L) + .annotations(annotations) + .meta(Map.of("metaKey", "metaValue")) + .build(); + + String value = JSON_MAPPER.writeValueAsString(resource); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"uri":"resource://test","name":"Test Resource","description":"A test resource","mimeType":"text/plain","size":256,"annotations":{"audience":["user","assistant"],"priority":0.8},"_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testResourceBuilderUriRequired() { + McpSchema.Annotations annotations = new McpSchema.Annotations( + Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); + + McpSchema.Resource.Builder resourceBuilder = McpSchema.Resource.builder() + .name("Test Resource") + .description("A test resource") + .mimeType("text/plain") + .size(256L) + .annotations(annotations); + + assertThatThrownBy(resourceBuilder::build).isInstanceOf(java.lang.IllegalArgumentException.class); + } + + @Test + void testResourceBuilderNameRequired() { + McpSchema.Annotations annotations = new McpSchema.Annotations( + Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); + + McpSchema.Resource.Builder resourceBuilder = McpSchema.Resource.builder() + .uri("resource://test") + .description("A test resource") + .mimeType("text/plain") + .size(256L) + .annotations(annotations); + + assertThatThrownBy(resourceBuilder::build).isInstanceOf(java.lang.IllegalArgumentException.class); + } + + @Test + void testResourceTemplate() throws Exception { + McpSchema.Annotations annotations = new McpSchema.Annotations(Arrays.asList(McpSchema.Role.USER), 0.5); + Map meta = Map.of("metaKey", "metaValue"); + + McpSchema.ResourceTemplate template = new McpSchema.ResourceTemplate("resource://{param}/test", "Test Template", + "Test Template", "A test resource template", "text/plain", annotations, meta); + + String value = JSON_MAPPER.writeValueAsString(template); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"uriTemplate":"resource://{param}/test","name":"Test Template","title":"Test Template","description":"A test resource template","mimeType":"text/plain","annotations":{"audience":["user"],"priority":0.5},"_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testListResourcesResult() throws Exception { + McpSchema.Resource resource1 = new McpSchema.Resource("resource://test1", "Test Resource 1", + "First test resource", "text/plain", null); + + McpSchema.Resource resource2 = new McpSchema.Resource("resource://test2", "Test Resource 2", + "Second test resource", "application/json", null); + + Map meta = Map.of("metaKey", "metaValue"); + + McpSchema.ListResourcesResult result = new McpSchema.ListResourcesResult(Arrays.asList(resource1, resource2), + "next-cursor", meta); + + String value = JSON_MAPPER.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"resources":[{"uri":"resource://test1","name":"Test Resource 1","description":"First test resource","mimeType":"text/plain"},{"uri":"resource://test2","name":"Test Resource 2","description":"Second test resource","mimeType":"application/json"}],"nextCursor":"next-cursor","_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testListResourceTemplatesResult() throws Exception { + McpSchema.ResourceTemplate template1 = new McpSchema.ResourceTemplate("resource://{param}/test1", + "Test Template 1", "Test Template 1", "First test template", "text/plain", null); + + McpSchema.ResourceTemplate template2 = new McpSchema.ResourceTemplate("resource://{param}/test2", + "Test Template 2", "Test Template 2", "Second test template", "application/json", null); + + McpSchema.ListResourceTemplatesResult result = new McpSchema.ListResourceTemplatesResult( + Arrays.asList(template1, template2), "next-cursor"); + + String value = JSON_MAPPER.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"resourceTemplates":[{"uriTemplate":"resource://{param}/test1","name":"Test Template 1","title":"Test Template 1","description":"First test template","mimeType":"text/plain"},{"uriTemplate":"resource://{param}/test2","name":"Test Template 2","title":"Test Template 2","description":"Second test template","mimeType":"application/json"}],"nextCursor":"next-cursor"}""")); + } + + @Test + void testReadResourceRequest() throws Exception { + McpSchema.ReadResourceRequest request = new McpSchema.ReadResourceRequest("resource://test", + Map.of("metaKey", "metaValue")); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"uri":"resource://test","_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testReadResourceRequestWithMeta() throws Exception { + Map meta = new HashMap<>(); + meta.put("progressToken", "read-resource-token-123"); + + McpSchema.ReadResourceRequest request = new McpSchema.ReadResourceRequest("resource://test", meta); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"uri":"resource://test","_meta":{"progressToken":"read-resource-token-123"}}""")); + + // Test Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("read-resource-token-123"); + } + + @Test + void testReadResourceRequestDeserialization() throws Exception { + McpSchema.ReadResourceRequest request = JSON_MAPPER.readValue(""" + {"uri":"resource://test","_meta":{"progressToken":"test-token"}}""", + McpSchema.ReadResourceRequest.class); + + assertThat(request.uri()).isEqualTo("resource://test"); + assertThat(request.meta()).containsEntry("progressToken", "test-token"); + assertThat(request.progressToken()).isEqualTo("test-token"); + } + + @Test + void testReadResourceResult() throws Exception { + McpSchema.TextResourceContents contents1 = new McpSchema.TextResourceContents("resource://test1", "text/plain", + "Sample text content"); + + McpSchema.BlobResourceContents contents2 = new McpSchema.BlobResourceContents("resource://test2", + "application/octet-stream", "base64encodedblob"); + + McpSchema.ReadResourceResult result = new McpSchema.ReadResourceResult(Arrays.asList(contents1, contents2), + Map.of("metaKey", "metaValue")); + + String value = JSON_MAPPER.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"contents":[{"uri":"resource://test1","mimeType":"text/plain","text":"Sample text content"},{"uri":"resource://test2","mimeType":"application/octet-stream","blob":"base64encodedblob"}],"_meta":{"metaKey":"metaValue"}}""")); + } + + // Prompt Tests + + @Test + void testPrompt() throws Exception { + McpSchema.PromptArgument arg1 = new McpSchema.PromptArgument("arg1", "First argument", "First argument", true); + + McpSchema.PromptArgument arg2 = new McpSchema.PromptArgument("arg2", "Second argument", "Second argument", + false); + + McpSchema.Prompt prompt = new McpSchema.Prompt("test-prompt", "Test Prompt", "A test prompt", + Arrays.asList(arg1, arg2), Map.of("metaKey", "metaValue")); + + String value = JSON_MAPPER.writeValueAsString(prompt); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"name":"test-prompt","title":"Test Prompt","description":"A test prompt","arguments":[{"name":"arg1","title":"First argument","description":"First argument","required":true},{"name":"arg2","title":"Second argument","description":"Second argument","required":false}],"_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testPromptMessage() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("Hello, world!"); + + McpSchema.PromptMessage message = new McpSchema.PromptMessage(McpSchema.Role.USER, content); + + String value = JSON_MAPPER.writeValueAsString(message); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"role":"user","content":{"type":"text","text":"Hello, world!"}}""")); + } + + @Test + void testListPromptsResult() throws Exception { + McpSchema.PromptArgument arg = new McpSchema.PromptArgument("arg", "Argument", "An argument", true); + + McpSchema.Prompt prompt1 = new McpSchema.Prompt("prompt1", "First prompt", "First prompt", + Collections.singletonList(arg)); + + McpSchema.Prompt prompt2 = new McpSchema.Prompt("prompt2", "Second prompt", "Second prompt", + Collections.emptyList()); + + McpSchema.ListPromptsResult result = new McpSchema.ListPromptsResult(Arrays.asList(prompt1, prompt2), + "next-cursor"); + + String value = JSON_MAPPER.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"prompts":[{"name":"prompt1","title":"First prompt","description":"First prompt","arguments":[{"name":"arg","title":"Argument","description":"An argument","required":true}]},{"name":"prompt2","title":"Second prompt","description":"Second prompt","arguments":[]}],"nextCursor":"next-cursor"}""")); + } + + @Test + void testGetPromptRequest() throws Exception { + Map arguments = new HashMap<>(); + arguments.put("arg1", "value1"); + arguments.put("arg2", 42); + + McpSchema.GetPromptRequest request = new McpSchema.GetPromptRequest("test-prompt", arguments); + + assertThat(JSON_MAPPER.readValue(""" + {"name":"test-prompt","arguments":{"arg1":"value1","arg2":42}}""", McpSchema.GetPromptRequest.class)) + .isEqualTo(request); + } + + @Test + void testGetPromptRequestWithMeta() throws Exception { + Map arguments = new HashMap<>(); + arguments.put("arg1", "value1"); + arguments.put("arg2", 42); + + Map meta = new HashMap<>(); + meta.put("progressToken", "token123"); + + McpSchema.GetPromptRequest request = new McpSchema.GetPromptRequest("test-prompt", arguments, meta); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"name":"test-prompt","arguments":{"arg1":"value1","arg2":42},"_meta":{"progressToken":"token123"}}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("token123"); + } + + @Test + void testGetPromptResult() throws Exception { + McpSchema.TextContent content1 = new McpSchema.TextContent("System message"); + McpSchema.TextContent content2 = new McpSchema.TextContent("User message"); + + McpSchema.PromptMessage message1 = new McpSchema.PromptMessage(McpSchema.Role.ASSISTANT, content1); + + McpSchema.PromptMessage message2 = new McpSchema.PromptMessage(McpSchema.Role.USER, content2); + + McpSchema.GetPromptResult result = new McpSchema.GetPromptResult("A test prompt result", + Arrays.asList(message1, message2)); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"description":"A test prompt result","messages":[{"role":"assistant","content":{"type":"text","text":"System message"}},{"role":"user","content":{"type":"text","text":"User message"}}]}""")); + } + + // Tool Tests + + @Test + void testJsonSchema() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "address": { + "$ref": "#/$defs/Address" + } + }, + "required": ["name"], + "$defs": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + } + } + """; + + // Deserialize the original string to a JsonSchema object + McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); + + // Serialize the object back to a string + String serialized = JSON_MAPPER.writeValueAsString(schema); + + // Deserialize again + McpSchema.JsonSchema deserialized = JSON_MAPPER.readValue(serialized, McpSchema.JsonSchema.class); + + // Serialize one more time and compare with the first serialization + String serializedAgain = JSON_MAPPER.writeValueAsString(deserialized); + + // The two serialized strings should be the same + assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); + } + + @Test + void testJsonSchemaWithDefinitions() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "address": { + "$ref": "#/definitions/Address" + } + }, + "required": ["name"], + "definitions": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + } + } + """; + + // Deserialize the original string to a JsonSchema object + McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); + + // Serialize the object back to a string + String serialized = JSON_MAPPER.writeValueAsString(schema); + + // Deserialize again + McpSchema.JsonSchema deserialized = JSON_MAPPER.readValue(serialized, McpSchema.JsonSchema.class); + + // Serialize one more time and compare with the first serialization + String serializedAgain = JSON_MAPPER.writeValueAsString(deserialized); + + // The two serialized strings should be the same + assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); + } + + @Test + void testTool() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "value": { + "type": "number" + } + }, + "required": ["name"] + } + """; + + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, schemaJson) + .build(); + + String value = JSON_MAPPER.writeValueAsString(tool); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"name":"test-tool","description":"A test tool","inputSchema":{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"number"}},"required":["name"]}}""")); + } + + @Test + void testToolWithComplexSchema() throws Exception { + String complexSchemaJson = """ + { + "type": "object", + "$defs": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + }, + "properties": { + "name": {"type": "string"}, + "shippingAddress": {"$ref": "#/$defs/Address"} + }, + "required": ["name", "shippingAddress"] + } + """; + + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("addressTool") + .title("Handles addresses") + .inputSchema(JSON_MAPPER, complexSchemaJson) + .build(); + + // Serialize the tool to a string + String serialized = JSON_MAPPER.writeValueAsString(tool); + + // Deserialize back to a Tool object + McpSchema.Tool deserializedTool = JSON_MAPPER.readValue(serialized, McpSchema.Tool.class); + + // Serialize again and compare with first serialization + String serializedAgain = JSON_MAPPER.writeValueAsString(deserializedTool); + + // The two serialized strings should be the same + assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); + + // Just verify the basic structure was preserved + assertThat(deserializedTool.inputSchema().defs()).isNotNull(); + assertThat(deserializedTool.inputSchema().defs()).containsKey("Address"); + } + + @Test + void testToolWithMeta() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "value": { + "type": "number" + } + }, + "required": ["name"] + } + """; + + McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); + Map meta = Map.of("metaKey", "metaValue"); + + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("addressTool") + .title("addressTool") + .description("Handles addresses") + .inputSchema(schema) + .meta(meta) + .build(); + + // Verify that meta value was preserved + assertThat(tool.meta()).isNotNull(); + assertThat(tool.meta()).containsKey("metaKey"); + } + + @Test + void testToolWithAnnotations() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "value": { + "type": "number" + } + }, + "required": ["name"] + } + """; + McpSchema.ToolAnnotations annotations = new McpSchema.ToolAnnotations("A test tool", false, false, false, false, + false); + + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, schemaJson) + .annotations(annotations) + .build(); + + String value = JSON_MAPPER.writeValueAsString(tool); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + { + "name":"test-tool", + "description":"A test tool", + "inputSchema":{ + "type":"object", + "properties":{ + "name":{"type":"string"}, + "value":{"type":"number"} + }, + "required":["name"] + }, + "annotations":{ + "title":"A test tool", + "readOnlyHint":false, + "destructiveHint":false, + "idempotentHint":false, + "openWorldHint":false, + "returnDirect":false + } + } + """)); + } + + @Test + void testToolWithOutputSchema() throws Exception { + String inputSchemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "value": { + "type": "number" + } + }, + "required": ["name"] + } + """; + + String outputSchemaJson = """ + { + "type": "object", + "properties": { + "result": { + "type": "string" + }, + "status": { + "type": "string", + "enum": ["success", "error"] + } + }, + "required": ["result", "status"] + } + """; + + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, inputSchemaJson) + .outputSchema(JSON_MAPPER, outputSchemaJson) + .build(); + + String value = JSON_MAPPER.writeValueAsString(tool); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + { + "name":"test-tool", + "description":"A test tool", + "inputSchema":{ + "type":"object", + "properties":{ + "name":{"type":"string"}, + "value":{"type":"number"} + }, + "required":["name"] + }, + "outputSchema":{ + "type":"object", + "properties":{ + "result":{"type":"string"}, + "status":{ + "type":"string", + "enum":["success","error"] + } + }, + "required":["result","status"] + } + } + """)); + } + + @Test + void testToolWithOutputSchemaAndAnnotations() throws Exception { + String inputSchemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + } + }, + "required": ["name"] + } + """; + + String outputSchemaJson = """ + { + "type": "object", + "properties": { + "result": { + "type": "string" + } + }, + "required": ["result"] + } + """; + + McpSchema.ToolAnnotations annotations = new McpSchema.ToolAnnotations("A test tool with output", true, false, + true, false, true); + + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, inputSchemaJson) + .outputSchema(JSON_MAPPER, outputSchemaJson) + .annotations(annotations) + .build(); + + String value = JSON_MAPPER.writeValueAsString(tool); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + { + "name":"test-tool", + "description":"A test tool", + "inputSchema":{ + "type":"object", + "properties":{ + "name":{"type":"string"} + }, + "required":["name"] + }, + "outputSchema":{ + "type":"object", + "properties":{ + "result":{"type":"string"} + }, + "required":["result"] + }, + "annotations":{ + "title":"A test tool with output", + "readOnlyHint":true, + "destructiveHint":false, + "idempotentHint":true, + "openWorldHint":false, + "returnDirect":true + } + }""")); + } + + @Test + void testToolDeserialization() throws Exception { + String toolJson = """ + { + "name": "test-tool", + "description": "A test tool", + "inputSchema": { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"] + }, + "outputSchema": { + "type": "object", + "properties": { + "result": {"type": "string"} + }, + "required": ["result"] + }, + "annotations": { + "title": "Test Tool", + "readOnlyHint": true, + "destructiveHint": false, + "idempotentHint": true, + "openWorldHint": false, + "returnDirect": false + } + } + """; + + McpSchema.Tool tool = JSON_MAPPER.readValue(toolJson, McpSchema.Tool.class); + + assertThat(tool).isNotNull(); + assertThat(tool.name()).isEqualTo("test-tool"); + assertThat(tool.description()).isEqualTo("A test tool"); + assertThat(tool.inputSchema()).isNotNull(); + assertThat(tool.inputSchema().type()).isEqualTo("object"); + assertThat(tool.outputSchema()).isNotNull(); + assertThat(tool.outputSchema()).containsKey("type"); + assertThat(tool.outputSchema().get("type")).isEqualTo("object"); + assertThat(tool.annotations()).isNotNull(); + assertThat(tool.annotations().title()).isEqualTo("Test Tool"); + assertThat(tool.annotations().readOnlyHint()).isTrue(); + assertThat(tool.annotations().idempotentHint()).isTrue(); + assertThat(tool.annotations().destructiveHint()).isFalse(); + assertThat(tool.annotations().returnDirect()).isFalse(); + } + + @Test + void testToolDeserializationWithoutOutputSchema() throws Exception { + String toolJson = """ + { + "name": "test-tool", + "description": "A test tool", + "inputSchema": { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"] + } + } + """; + + McpSchema.Tool tool = JSON_MAPPER.readValue(toolJson, McpSchema.Tool.class); + + assertThat(tool).isNotNull(); + assertThat(tool.name()).isEqualTo("test-tool"); + assertThat(tool.description()).isEqualTo("A test tool"); + assertThat(tool.inputSchema()).isNotNull(); + assertThat(tool.outputSchema()).isNull(); + assertThat(tool.annotations()).isNull(); + } + + @Test + void testCallToolRequest() throws Exception { + Map arguments = new HashMap<>(); + arguments.put("name", "test"); + arguments.put("value", 42); + + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest("test-tool", arguments); + + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"name":"test-tool","arguments":{"name":"test","value":42}}""")); + } + + @Test + void testCallToolRequestJsonArguments() throws Exception { + + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(JSON_MAPPER, "test-tool", """ + { + "name": "test", + "value": 42 + } + """); + + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"name":"test-tool","arguments":{"name":"test","value":42}}""")); + } + + @Test + void testCallToolRequestWithMeta() throws Exception { + + McpSchema.CallToolRequest request = McpSchema.CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of("name", "test", "value", 42)) + .progressToken("tool-progress-123") + .build(); + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"name":"test-tool","arguments":{"name":"test","value":42},"_meta":{"progressToken":"tool-progress-123"}}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isEqualTo(Map.of("progressToken", "tool-progress-123")); + assertThat(request.progressToken()).isEqualTo("tool-progress-123"); + } + + @Test + void testCallToolRequestBuilderWithJsonArguments() throws Exception { + Map meta = new HashMap<>(); + meta.put("progressToken", "json-builder-789"); + + McpSchema.CallToolRequest request = McpSchema.CallToolRequest.builder() + .name("test-tool") + .arguments(JSON_MAPPER, """ + { + "name": "test", + "value": 42 + } + """) + .meta(meta) + .build(); + + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"name":"test-tool","arguments":{"name":"test","value":42},"_meta":{"progressToken":"json-builder-789"}}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("json-builder-789"); + } + + @Test + void testCallToolRequestBuilderNameRequired() { + Map arguments = new HashMap<>(); + arguments.put("name", "test"); + + McpSchema.CallToolRequest.Builder builder = McpSchema.CallToolRequest.builder().arguments(arguments); + + assertThatThrownBy(builder::build).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("name must not be empty"); + } + + @Test + void testCallToolResult() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("Tool execution result"); + + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .content(Collections.singletonList(content)) + .build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"content":[{"type":"text","text":"Tool execution result"}],"isError":false}""")); + } + + @Test + void testCallToolResultBuilder() throws Exception { + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .addTextContent("Tool execution result") + .isError(false) + .build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"content":[{"type":"text","text":"Tool execution result"}],"isError":false}""")); + } + + @Test + void testCallToolResultBuilderWithMultipleContents() throws Exception { + McpSchema.TextContent textContent = new McpSchema.TextContent("Text result"); + McpSchema.ImageContent imageContent = new McpSchema.ImageContent(null, null, "base64data", "image/png"); + + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .addContent(textContent) + .addContent(imageContent) + .isError(false) + .build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"content":[{"type":"text","text":"Text result"},{"type":"image","data":"base64data","mimeType":"image/png"}],"isError":false}""")); + } + + @Test + void testCallToolResultBuilderWithContentList() throws Exception { + McpSchema.TextContent textContent = new McpSchema.TextContent("Text result"); + McpSchema.ImageContent imageContent = new McpSchema.ImageContent(null, null, "base64data", "image/png"); + List contents = Arrays.asList(textContent, imageContent); + + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder().content(contents).isError(true).build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"content":[{"type":"text","text":"Text result"},{"type":"image","data":"base64data","mimeType":"image/png"}],"isError":true}""")); + } + + @Test + void testCallToolResultBuilderWithErrorResult() throws Exception { + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .addTextContent("Error: Operation failed") + .isError(true) + .build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"content":[{"type":"text","text":"Error: Operation failed"}],"isError":true}""")); + } + + @Test + void testCallToolResultStringConstructor() throws Exception { + // Test the existing string constructor alongside the builder + McpSchema.CallToolResult result1 = new McpSchema.CallToolResult("Simple result", false); + McpSchema.CallToolResult result2 = McpSchema.CallToolResult.builder() + .addTextContent("Simple result") + .isError(false) + .build(); + + String value1 = JSON_MAPPER.writeValueAsString(result1); + String value2 = JSON_MAPPER.writeValueAsString(result2); + + // Both should produce the same JSON + assertThat(value1).isEqualTo(value2); + assertThatJson(value1).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"content":[{"type":"text","text":"Simple result"}],"isError":false}""")); + } + + // Sampling Tests + + @Test + void testCreateMessageRequest() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("User message"); + + McpSchema.SamplingMessage message = new McpSchema.SamplingMessage(McpSchema.Role.USER, content); + + McpSchema.ModelHint hint = new McpSchema.ModelHint("gpt-4"); + + McpSchema.ModelPreferences preferences = new McpSchema.ModelPreferences(Collections.singletonList(hint), 0.3, + 0.7, 0.9); + + Map metadata = new HashMap<>(); + metadata.put("session", "test-session"); + + McpSchema.CreateMessageRequest request = McpSchema.CreateMessageRequest.builder() + .messages(Collections.singletonList(message)) + .modelPreferences(preferences) + .systemPrompt("You are a helpful assistant") + .includeContext(McpSchema.CreateMessageRequest.ContextInclusionStrategy.THIS_SERVER) + .temperature(0.7) + .maxTokens(1000) + .stopSequences(Arrays.asList("STOP", "END")) + .metadata(metadata) + .build(); + + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"messages":[{"role":"user","content":{"type":"text","text":"User message"}}],"modelPreferences":{"hints":[{"name":"gpt-4"}],"costPriority":0.3,"speedPriority":0.7,"intelligencePriority":0.9},"systemPrompt":"You are a helpful assistant","includeContext":"thisServer","temperature":0.7,"maxTokens":1000,"stopSequences":["STOP","END"],"metadata":{"session":"test-session"}}""")); + } + + @Test + void testCreateMessageResult() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("Assistant response"); + + McpSchema.CreateMessageResult result = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(content) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"endTurn"}""")); + } + + @Test + void testCreateMessageResultUnknownStopReason() throws Exception { + String input = """ + {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"arbitrary value"}"""; + + McpSchema.CreateMessageResult value = JSON_MAPPER.readValue(input, McpSchema.CreateMessageResult.class); + + McpSchema.TextContent expectedContent = new McpSchema.TextContent("Assistant response"); + McpSchema.CreateMessageResult expected = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(expectedContent) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.UNKNOWN) + .build(); + assertThat(value).isEqualTo(expected); + } + + // Elicitation Tests + + @Test + void testCreateElicitationRequest() throws Exception { + McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() + .requestedSchema(Map.of("type", "object", "required", List.of("a"), "properties", + Map.of("foo", Map.of("type", "string")))) + .build(); + + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"requestedSchema":{"properties":{"foo":{"type":"string"}},"required":["a"],"type":"object"}}""")); + } + + @Test + void testCreateElicitationResult() throws Exception { + McpSchema.ElicitResult result = McpSchema.ElicitResult.builder() + .content(Map.of("foo", "bar")) + .message(McpSchema.ElicitResult.Action.ACCEPT) + .build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"action":"accept","content":{"foo":"bar"}}""")); + } + + @Test + void testElicitRequestWithMeta() throws Exception { + Map requestedSchema = Map.of("type", "object", "required", List.of("name"), "properties", + Map.of("name", Map.of("type", "string"))); + + Map meta = new HashMap<>(); + meta.put("progressToken", "elicit-token-789"); + + McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .requestedSchema(requestedSchema) + .meta(meta) + .build(); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .containsEntry("_meta", Map.of("progressToken", "elicit-token-789")); + + // Test Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("elicit-token-789"); + } + + // Pagination Tests + + @Test + void testPaginatedRequestNoArgs() throws Exception { + McpSchema.PaginatedRequest request = new McpSchema.PaginatedRequest(); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isNull(); + assertThat(request.progressToken()).isNull(); + } + + @Test + void testPaginatedRequestWithCursor() throws Exception { + McpSchema.PaginatedRequest request = new McpSchema.PaginatedRequest("cursor123"); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"cursor":"cursor123"}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isNull(); + assertThat(request.progressToken()).isNull(); + } + + @Test + void testPaginatedRequestWithMeta() throws Exception { + Map meta = new HashMap<>(); + meta.put("progressToken", "pagination-progress-456"); + + McpSchema.PaginatedRequest request = new McpSchema.PaginatedRequest("cursor123", meta); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"cursor":"cursor123","_meta":{"progressToken":"pagination-progress-456"}}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("pagination-progress-456"); + } + + @Test + void testPaginatedRequestDeserialization() throws Exception { + McpSchema.PaginatedRequest request = JSON_MAPPER.readValue(""" + {"cursor":"test-cursor","_meta":{"progressToken":"test-token"}}""", McpSchema.PaginatedRequest.class); + + assertThat(request.cursor()).isEqualTo("test-cursor"); + assertThat(request.meta()).containsEntry("progressToken", "test-token"); + assertThat(request.progressToken()).isEqualTo("test-token"); + } + + // Complete Request Tests + + @Test + void testCompleteRequest() throws Exception { + McpSchema.PromptReference promptRef = new McpSchema.PromptReference("test-prompt"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument("arg1", + "partial-value"); + + McpSchema.CompleteRequest request = new McpSchema.CompleteRequest(promptRef, argument); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"ref":{"type":"ref/prompt","name":"test-prompt"},"argument":{"name":"arg1","value":"partial-value"}}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isNull(); + assertThat(request.progressToken()).isNull(); + } + + @Test + void testCompleteRequestWithMeta() throws Exception { + McpSchema.ResourceReference resourceRef = new McpSchema.ResourceReference("file:///test.txt"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument("path", + "/partial/path"); + + Map meta = new HashMap<>(); + meta.put("progressToken", "complete-progress-789"); + + McpSchema.CompleteRequest request = new McpSchema.CompleteRequest(resourceRef, argument, meta, null); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"ref":{"type":"ref/resource","uri":"file:///test.txt"},"argument":{"name":"path","value":"/partial/path"},"_meta":{"progressToken":"complete-progress-789"}}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("complete-progress-789"); + } + + // Roots Tests + + @Test + void testRoot() throws Exception { + McpSchema.Root root = new McpSchema.Root("file:///path/to/root", "Test Root", Map.of("metaKey", "metaValue")); + + String value = JSON_MAPPER.writeValueAsString(root); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"uri":"file:///path/to/root","name":"Test Root","_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testListRootsResult() throws Exception { + McpSchema.Root root1 = new McpSchema.Root("file:///path/to/root1", "First Root"); + + McpSchema.Root root2 = new McpSchema.Root("file:///path/to/root2", "Second Root"); + + McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(Arrays.asList(root1, root2), "next-cursor"); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"roots":[{"uri":"file:///path/to/root1","name":"First Root"},{"uri":"file:///path/to/root2","name":"Second Root"}],"nextCursor":"next-cursor"}""")); + + } + + // Elicitation Capability Tests (Issue #724) + + @Test + void testElicitationCapabilityWithFormField() throws Exception { + // Test that elicitation with "form" field can be deserialized (2025-11-25 spec) + String json = """ + {"protocolVersion":"2024-11-05","capabilities":{"elicitation":{"form":{}}},"clientInfo":{"name":"test-client","version":"1.0.0"}} + """; + + McpSchema.InitializeRequest request = JSON_MAPPER.readValue(json, McpSchema.InitializeRequest.class); + + assertThat(request).isNotNull(); + assertThat(request.capabilities()).isNotNull(); + assertThat(request.capabilities().elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityWithFormAndUrlFields() throws Exception { + // Test that elicitation with both "form" and "url" fields can be deserialized + String json = """ + {"protocolVersion":"2024-11-05","capabilities":{"elicitation":{"form":{},"url":{}}},"clientInfo":{"name":"test-client","version":"1.0.0"}} + """; + + McpSchema.InitializeRequest request = JSON_MAPPER.readValue(json, McpSchema.InitializeRequest.class); + + assertThat(request).isNotNull(); + assertThat(request.capabilities()).isNotNull(); + assertThat(request.capabilities().elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityBackwardCompatibilityEmptyObject() throws Exception { + // Test backward compatibility: empty elicitation {} should still work + String json = """ + {"protocolVersion":"2024-11-05","capabilities":{"elicitation":{}},"clientInfo":{"name":"test-client","version":"1.0.0"}} + """; + + McpSchema.InitializeRequest request = JSON_MAPPER.readValue(json, McpSchema.InitializeRequest.class); + + assertThat(request).isNotNull(); + assertThat(request.capabilities()).isNotNull(); + assertThat(request.capabilities().elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityBuilderBackwardCompatibility() throws Exception { + // Test that the existing builder API still works and produces valid JSON + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder().elicitation().build(); + + assertThat(capabilities.elicitation()).isNotNull(); + + // Serialize and verify it produces valid JSON (should be {} for backward compat) + String json = JSON_MAPPER.writeValueAsString(capabilities); + assertThat(json).contains("\"elicitation\""); + } + + @Test + void testElicitationCapabilitySerializationRoundTrip() throws Exception { + // Test that serialization and deserialization round-trip works + McpSchema.ClientCapabilities original = McpSchema.ClientCapabilities.builder().elicitation().build(); + + String json = JSON_MAPPER.writeValueAsString(original); + McpSchema.ClientCapabilities deserialized = JSON_MAPPER.readValue(json, McpSchema.ClientCapabilities.class); + + assertThat(deserialized.elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityBuilderWithFormAndUrl() throws Exception { + // Test the new builder method that explicitly sets form and url support + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() + .elicitation(true, true) + .build(); + + assertThat(capabilities.elicitation()).isNotNull(); + assertThat(capabilities.elicitation().form()).isNotNull(); + assertThat(capabilities.elicitation().url()).isNotNull(); + + // Verify serialization produces the expected JSON + String json = JSON_MAPPER.writeValueAsString(capabilities); + assertThatJson(json).when(Option.IGNORING_ARRAY_ORDER).isObject().containsKey("elicitation"); + assertThat(json).contains("\"form\""); + assertThat(json).contains("\"url\""); + } + + @Test + void testElicitationCapabilityBuilderFormOnly() throws Exception { + // Test builder with form only + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() + .elicitation(true, false) + .build(); + + assertThat(capabilities.elicitation()).isNotNull(); + assertThat(capabilities.elicitation().form()).isNotNull(); + assertThat(capabilities.elicitation().url()).isNull(); + + String json = JSON_MAPPER.writeValueAsString(capabilities); + assertThat(json).contains("\"form\""); + assertThat(json).doesNotContain("\"url\""); + } + + // Progress Notification Tests + + @Test + void testProgressNotificationWithMessage() throws Exception { + McpSchema.ProgressNotification notification = new McpSchema.ProgressNotification("progress-token-123", 0.5, 1.0, + "Processing file 1 of 2", Map.of("key", "value")); + + String value = JSON_MAPPER.writeValueAsString(notification); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"progressToken":"progress-token-123","progress":0.5,"total":1.0,"message":"Processing file 1 of 2","_meta":{"key":"value"}}""")); + } + + @Test + void testProgressNotificationDeserialization() throws Exception { + McpSchema.ProgressNotification notification = JSON_MAPPER.readValue( + """ + {"progressToken":"token-456","progress":0.75,"total":1.0,"message":"Almost done","_meta":{"key":"value"}}""", + McpSchema.ProgressNotification.class); + + assertThat(notification.progressToken()).isEqualTo("token-456"); + assertThat(notification.progress()).isEqualTo(0.75); + assertThat(notification.total()).isEqualTo(1.0); + assertThat(notification.message()).isEqualTo("Almost done"); + assertThat(notification.meta()).containsEntry("key", "value"); + } + + @Test + void testProgressNotificationWithoutMessage() throws Exception { + McpSchema.ProgressNotification notification = new McpSchema.ProgressNotification("progress-token-789", 0.25, + null, null); + + String value = JSON_MAPPER.writeValueAsString(notification); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"progressToken":"progress-token-789","progress":0.25}""")); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java new file mode 100644 index 000000000..1d7be0b51 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java @@ -0,0 +1,100 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ + +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Test class to verify the equals method implementation for PromptReference. + */ +class PromptReferenceEqualsTest { + + @Test + void testEqualsWithSameIdentifierAndType() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Different Title"); + + assertTrue(ref1.equals(ref2), "PromptReferences with same identifier and type should be equal"); + assertEquals(ref1.hashCode(), ref2.hashCode(), "Equal objects should have same hash code"); + } + + @Test + void testEqualsWithDifferentIdentifier() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt-1", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt-2", + "Test Title"); + + assertFalse(ref1.equals(ref2), "PromptReferences with different identifiers should not be equal"); + } + + @Test + void testEqualsWithDifferentType() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference("ref/other", "test-prompt", "Test Title"); + + assertFalse(ref1.equals(ref2), "PromptReferences with different types should not be equal"); + } + + @Test + void testEqualsWithNull() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + + assertFalse(ref1.equals(null), "PromptReference should not be equal to null"); + } + + @Test + void testEqualsWithDifferentClass() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + String other = "not a PromptReference"; + + assertFalse(ref1.equals(other), "PromptReference should not be equal to different class"); + } + + @Test + void testEqualsWithSameInstance() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + + assertTrue(ref1.equals(ref1), "PromptReference should be equal to itself"); + } + + @Test + void testEqualsIgnoresTitle() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", "Title 1"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", "Title 2"); + McpSchema.PromptReference ref3 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", null); + + assertTrue(ref1.equals(ref2), "PromptReferences should be equal regardless of title"); + assertTrue(ref1.equals(ref3), "PromptReferences should be equal even when one has null title"); + assertTrue(ref2.equals(ref3), "PromptReferences should be equal even when one has null title"); + } + + @Test + void testHashCodeConsistency() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Different Title"); + + assertEquals(ref1.hashCode(), ref2.hashCode(), "Objects that are equal should have the same hash code"); + + // Call hashCode multiple times to ensure consistency + int hashCode1 = ref1.hashCode(); + int hashCode2 = ref1.hashCode(); + assertEquals(hashCode1, hashCode2, "Hash code should be consistent across multiple calls"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapper.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapper.java new file mode 100644 index 000000000..ef7cd2737 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapper.java @@ -0,0 +1,97 @@ +package io.modelcontextprotocol.spec.json.gson; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.ToNumberPolicy; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +/** + * Test-only Gson-based implementation of McpJsonMapper. This lives under src/test/java so + * it doesn't affect production code or dependencies. + */ +public final class GsonMcpJsonMapper implements McpJsonMapper { + + private final Gson gson; + + public GsonMcpJsonMapper() { + this(new GsonBuilder().serializeNulls() + // Ensure numeric values in untyped (Object) fields preserve integral numbers + // as Long + .setObjectToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE) + .setNumberToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE) + .create()); + } + + public GsonMcpJsonMapper(Gson gson) { + if (gson == null) { + throw new IllegalArgumentException("Gson must not be null"); + } + this.gson = gson; + } + + public Gson getGson() { + return gson; + } + + @Override + public T readValue(String content, Class type) throws IOException { + try { + return gson.fromJson(content, type); + } + catch (Exception e) { + throw new IOException("Failed to deserialize JSON", e); + } + } + + @Override + public T readValue(byte[] content, Class type) throws IOException { + return readValue(new String(content, StandardCharsets.UTF_8), type); + } + + @Override + public T readValue(String content, TypeRef type) throws IOException { + try { + return gson.fromJson(content, type.getType()); + } + catch (Exception e) { + throw new IOException("Failed to deserialize JSON", e); + } + } + + @Override + public T readValue(byte[] content, TypeRef type) throws IOException { + return readValue(new String(content, StandardCharsets.UTF_8), type); + } + + @Override + public T convertValue(Object fromValue, Class type) { + String json = gson.toJson(fromValue); + return gson.fromJson(json, type); + } + + @Override + public T convertValue(Object fromValue, TypeRef type) { + String json = gson.toJson(fromValue); + return gson.fromJson(json, type.getType()); + } + + @Override + public String writeValueAsString(Object value) throws IOException { + try { + return gson.toJson(value); + } + catch (Exception e) { + throw new IOException("Failed to serialize to JSON", e); + } + } + + @Override + public byte[] writeValueAsBytes(Object value) throws IOException { + return writeValueAsString(value).getBytes(StandardCharsets.UTF_8); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapperTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapperTests.java new file mode 100644 index 000000000..498194d17 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapperTests.java @@ -0,0 +1,135 @@ +package io.modelcontextprotocol.spec.json.gson; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +class GsonMcpJsonMapperTests { + + record Person(String name, int age) { + } + + @Test + void roundTripSimplePojo() throws IOException { + var mapper = new GsonMcpJsonMapper(); + + var input = new Person("Alice", 30); + String json = mapper.writeValueAsString(input); + assertNotNull(json); + assertTrue(json.contains("\"Alice\"")); + assertTrue(json.contains("\"age\"")); + + var decoded = mapper.readValue(json, Person.class); + assertEquals(input, decoded); + + byte[] bytes = mapper.writeValueAsBytes(input); + assertNotNull(bytes); + var decodedFromBytes = mapper.readValue(bytes, Person.class); + assertEquals(input, decodedFromBytes); + } + + @Test + void readWriteParameterizedTypeWithTypeRef() throws IOException { + var mapper = new GsonMcpJsonMapper(); + String json = "[\"a\", \"b\", \"c\"]"; + + List list = mapper.readValue(json, new TypeRef>() { + }); + assertEquals(List.of("a", "b", "c"), list); + + String encoded = mapper.writeValueAsString(list); + assertTrue(encoded.startsWith("[")); + assertTrue(encoded.contains("\"a\"")); + } + + @Test + void convertValueMapToRecordAndParameterized() { + var mapper = new GsonMcpJsonMapper(); + Map src = Map.of("name", "Bob", "age", 42); + + // Convert to simple record + Person person = mapper.convertValue(src, Person.class); + assertEquals(new Person("Bob", 42), person); + + // Convert to parameterized Map + Map toMap = mapper.convertValue(person, new TypeRef>() { + }); + assertEquals("Bob", toMap.get("name")); + assertEquals(42.0, ((Number) toMap.get("age")).doubleValue(), 0.0); // Gson may + // emit double + // for + // primitives + } + + @Test + void deserializeJsonRpcMessageRequestUsingCustomMapper() throws IOException { + var mapper = new GsonMcpJsonMapper(); + + String json = """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "ping", + "params": { "x": 1, "y": "z" } + } + """; + + var msg = McpSchema.deserializeJsonRpcMessage(mapper, json); + assertTrue(msg instanceof McpSchema.JSONRPCRequest); + + var req = (McpSchema.JSONRPCRequest) msg; + assertEquals("2.0", req.jsonrpc()); + assertEquals("ping", req.method()); + assertNotNull(req.id()); + assertEquals("1", req.id().toString()); + + assertNotNull(req.params()); + assertInstanceOf(Map.class, req.params()); + @SuppressWarnings("unchecked") + var params = (Map) req.params(); + assertEquals(1.0, ((Number) params.get("x")).doubleValue(), 0.0); + assertEquals("z", params.get("y")); + } + + @Test + void integrateWithMcpSchemaStaticMapperForStringParsing() { + var gsonMapper = new GsonMcpJsonMapper(); + + // Tool builder parsing of input/output schema strings + var tool = McpSchema.Tool.builder().name("echo").description("Echo tool").inputSchema(gsonMapper, """ + { + "type": "object", + "properties": { "x": { "type": "integer" } }, + "required": ["x"] + } + """).outputSchema(gsonMapper, """ + { + "type": "object", + "properties": { "y": { "type": "string" } } + } + """).build(); + + assertNotNull(tool.inputSchema()); + assertNotNull(tool.outputSchema()); + assertTrue(tool.outputSchema().containsKey("properties")); + + // CallToolRequest builder parsing of JSON arguments string + var call = McpSchema.CallToolRequest.builder().name("echo").arguments(gsonMapper, "{\"x\": 123}").build(); + + assertEquals("echo", call.name()); + assertNotNull(call.arguments()); + assertTrue(call.arguments().get("x") instanceof Number); + assertEquals(123.0, ((Number) call.arguments().get("x")).doubleValue(), 0.0); + + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/AssertTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/AssertTests.java new file mode 100644 index 000000000..0038d4e1b --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/AssertTests.java @@ -0,0 +1,48 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +class AssertTests { + + @Test + void testCollectionNotEmpty() { + IllegalArgumentException e1 = assertThrows(IllegalArgumentException.class, + () -> Assert.notEmpty(null, "collection is null")); + assertEquals("collection is null", e1.getMessage()); + + IllegalArgumentException e2 = assertThrows(IllegalArgumentException.class, + () -> Assert.notEmpty(List.of(), "collection is empty")); + assertEquals("collection is empty", e2.getMessage()); + + assertDoesNotThrow(() -> Assert.notEmpty(List.of("test"), "collection is not empty")); + } + + @Test + void testObjectNotNull() { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Assert.notNull(null, "object is null")); + assertEquals("object is null", e.getMessage()); + + assertDoesNotThrow(() -> Assert.notNull("test", "object is not null")); + } + + @Test + void testStringHasText() { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Assert.hasText(null, "string is null")); + assertEquals("string is null", e.getMessage()); + + assertDoesNotThrow(() -> Assert.hasText("test", "string is not empty")); + } + +} \ No newline at end of file diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java new file mode 100644 index 000000000..d5ef8a91c --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java @@ -0,0 +1,303 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.json.TypeRef; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSession; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.scheduler.VirtualTimeScheduler; + +/** + * Unit tests for {@link KeepAliveScheduler}. + * + * @author Christian Tzolov + */ +class KeepAliveSchedulerTests { + + private MockMcpSession mockSession1; + + private MockMcpSession mockSession2; + + private Supplier> mockSessionsSupplier; + + private VirtualTimeScheduler virtualTimeScheduler; + + @BeforeEach + void setUp() { + virtualTimeScheduler = VirtualTimeScheduler.create(); + mockSession1 = new MockMcpSession(); + mockSession2 = new MockMcpSession(); + mockSessionsSupplier = () -> Flux.just(mockSession1); + } + + @AfterEach + void tearDown() { + if (virtualTimeScheduler != null) { + virtualTimeScheduler.dispose(); + } + } + + @Test + void testBuilderWithNullSessionsSupplier() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("McpSessions supplier must not be null"); + } + + @Test + void testBuilderWithNullScheduler() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).scheduler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Scheduler must not be null"); + } + + @Test + void testBuilderWithNullInitialDelay() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).initialDelay(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Initial delay must not be null"); + } + + @Test + void testBuilderWithNullInterval() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).interval(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Interval must not be null"); + } + + @Test + void testBuilderDefaults() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier).build(); + + assertThat(scheduler).isNotNull(); + assertThat(scheduler.isRunning()).isFalse(); + } + + @Test + void testStartWithMultipleSessions() { + mockSessionsSupplier = () -> Flux.just(mockSession1, mockSession2); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .build(); + + assertThat(scheduler.isRunning()).isFalse(); + + // Start the scheduler + Disposable disposable = scheduler.start(); + + assertThat(scheduler.isRunning()).isTrue(); + assertThat(disposable).isNotNull(); + assertThat(disposable.isDisposed()).isFalse(); + + // Advance time to trigger the first ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + // Verify both sessions received ping + assertThat(mockSession1.getPingCount()).isEqualTo(1); + assertThat(mockSession2.getPingCount()).isEqualTo(1); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); // Second ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); // Third ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); // Fourth ping + + // Verify second ping was sent + assertThat(mockSession1.getPingCount()).isEqualTo(4); + assertThat(mockSession2.getPingCount()).isEqualTo(4); + + // Clean up + scheduler.stop(); + + assertThat(scheduler.isRunning()).isFalse(); + assertThat(disposable).isNotNull(); + assertThat(disposable.isDisposed()).isTrue(); + } + + @Test + void testStartWithEmptySessionsList() { + mockSessionsSupplier = () -> Flux.empty(); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .build(); + + // Start the scheduler + scheduler.start(); + + // Advance time to trigger ping attempts + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + // Verify no sessions were called (since list was empty) + assertThat(mockSession1.getPingCount()).isEqualTo(0); + assertThat(mockSession2.getPingCount()).isEqualTo(0); + + // Clean up + scheduler.stop(); + } + + @Test + void testStartWhenAlreadyRunning() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .build(); + + // Start the scheduler + scheduler.start(); + + // Try to start again - should throw exception + assertThatThrownBy(scheduler::start).isInstanceOf(IllegalStateException.class) + .hasMessage("KeepAlive scheduler is already running. Stop it first."); + + // Clean up + scheduler.stop(); + } + + @Test + void testStopWhenNotRunning() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .build(); + + // Should not throw exception when stopping a non-running scheduler + assertDoesNotThrow(scheduler::stop); + assertThat(scheduler.isRunning()).isFalse(); + } + + @Test + void testShutdown() { + // Setup with a separate virtual time scheduler (which is disposable) + VirtualTimeScheduler separateScheduler = VirtualTimeScheduler.create(); + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(separateScheduler) + .build(); + + // Start the scheduler + scheduler.start(); + assertThat(scheduler.isRunning()).isTrue(); + + // Shutdown should stop the scheduler and dispose the scheduler + scheduler.shutdown(); + assertThat(scheduler.isRunning()).isFalse(); + assertThat(separateScheduler.isDisposed()).isTrue(); + } + + @Test + void testPingFailureHandling() { + // Setup session that fails ping + mockSession1.setShouldFailPing(true); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .build(); + + // Start the scheduler + scheduler.start(); + + // Advance time to trigger the ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + // Verify ping was attempted (error should be handled gracefully) + assertThat(mockSession1.getPingCount()).isEqualTo(1); + + // Scheduler should still be running despite the error + assertThat(scheduler.isRunning()).isTrue(); + + // Clean up + scheduler.stop(); + } + + @Test + void testDisposableReturnedFromStart() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .build(); + + // Start and get disposable + Disposable disposable = scheduler.start(); + + assertThat(disposable).isNotNull(); + assertThat(disposable.isDisposed()).isFalse(); + assertThat(scheduler.isRunning()).isTrue(); + + // Dispose directly through the returned disposable + disposable.dispose(); + + assertThat(disposable.isDisposed()).isTrue(); + assertThat(scheduler.isRunning()).isFalse(); + } + + /** + * Simple mock implementation of McpSession for testing purposes. + */ + private static class MockMcpSession implements McpSession { + + private final AtomicInteger pingCount = new AtomicInteger(0); + + private boolean shouldFailPing = false; + + @Override + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { + if (McpSchema.METHOD_PING.equals(method)) { + pingCount.incrementAndGet(); + if (shouldFailPing) { + return Mono.error(new RuntimeException("Connection failed")); + } + return Mono.just((T) new Object()); + } + return Mono.empty(); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.empty(); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public void close() { + // No-op for mock + } + + public int getPingCount() { + return pingCount.get(); + } + + public void setShouldFailPing(boolean shouldFailPing) { + this.shouldFailPing = shouldFailPing; + } + + @Override + public String toString() { + return "MockMcpSession"; + } + + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java new file mode 100644 index 000000000..911506e01 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.json.McpJsonMapper; + +public final class McpJsonMapperUtils { + + private McpJsonMapperUtils() { + } + + public static final McpJsonMapper JSON_MAPPER = McpJsonMapper.getDefault(); + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java new file mode 100644 index 000000000..ce8755223 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.spec.McpSchema; + +import java.util.Collections; + +public final class ToolsUtils { + + private ToolsUtils() { + } + + public static final McpSchema.JsonSchema EMPTY_JSON_SCHEMA = new McpSchema.JsonSchema("object", + Collections.emptyMap(), null, null, null, null); + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/UtilsTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/UtilsTests.java new file mode 100644 index 000000000..0f2e689b5 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/UtilsTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +class UtilsTests { + + @Test + void testHasText() { + assertFalse(Utils.hasText(null)); + assertFalse(Utils.hasText("")); + assertFalse(Utils.hasText(" ")); + assertTrue(Utils.hasText("test")); + } + + @Test + void testCollectionIsEmpty() { + assertTrue(Utils.isEmpty((Collection) null)); + assertTrue(Utils.isEmpty(List.of())); + assertFalse(Utils.isEmpty(List.of("test"))); + } + + @Test + void testMapIsEmpty() { + assertTrue(Utils.isEmpty((Map) null)); + assertTrue(Utils.isEmpty(Map.of())); + assertFalse(Utils.isEmpty(Map.of("key", "value"))); + } + + @ParameterizedTest + @CsvSource({ + // relative endpoints + "http://localhost:8080/root, /api/v1, http://localhost:8080/api/v1", + "http://localhost:8080/root/, api, http://localhost:8080/root/api", + "http://localhost:8080, /api, http://localhost:8080/api", + // absolute endpoints matching base + "http://localhost:8080/root, http://localhost:8080/root/api/v1, http://localhost:8080/root/api/v1", + "http://localhost:8080/root, http://localhost:8080/root, http://localhost:8080/root" }) + void testValidUriResolution(String baseUrl, String endpoint, String expectedResult) { + URI result = Utils.resolveUri(URI.create(baseUrl), endpoint); + assertThat(result.toString()).isEqualTo(expectedResult); + } + + @ParameterizedTest + @CsvSource({ "http://localhost:8080/root, http://localhost:8080/other/api", + "http://localhost:8080/root, http://otherhost/api", + "http://localhost:8080/root, http://localhost:9090/root/api" }) + void testAbsoluteUriNotMatchingBase(String baseUrl, String endpoint) { + assertThatThrownBy(() -> Utils.resolveUri(URI.create(baseUrl), endpoint)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("does not match the base URL"); + } + +} \ No newline at end of file diff --git a/mcp/src/test/resources/logback.xml b/mcp-core/src/test/resources/logback.xml similarity index 100% rename from mcp/src/test/resources/logback.xml rename to mcp-core/src/test/resources/logback.xml diff --git a/mcp-json-jackson2/pom.xml b/mcp-json-jackson2/pom.xml new file mode 100644 index 000000000..de2ac58ce --- /dev/null +++ b/mcp-json-jackson2/pom.xml @@ -0,0 +1,79 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + mcp-parent + 0.18.0-SNAPSHOT + + mcp-json-jackson2 + jar + Java MCP SDK JSON Jackson + Java MCP SDK JSON implementation based on Jackson + https://github.com/modelcontextprotocol/java-sdk + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + true + + + + + + + + + io.modelcontextprotocol.sdk + mcp-json + 0.18.0-SNAPSHOT + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + com.networknt + json-schema-validator + ${json-schema-validator.version} + + + + org.assertj + assertj-core + ${assert4j.version} + test + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + + diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapper.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapper.java new file mode 100644 index 000000000..6aa2b4ebc --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapper.java @@ -0,0 +1,87 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.jackson; + +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; + +import java.io.IOException; + +/** + * Jackson-based implementation of JsonMapper. Wraps a Jackson ObjectMapper but keeps the + * SDK decoupled from Jackson at the API level. + */ +public final class JacksonMcpJsonMapper implements McpJsonMapper { + + private final ObjectMapper objectMapper; + + /** + * Constructs a new JacksonMcpJsonMapper instance with the given ObjectMapper. + * @param objectMapper the ObjectMapper to be used for JSON serialization and + * deserialization. Must not be null. + * @throws IllegalArgumentException if the provided ObjectMapper is null. + */ + public JacksonMcpJsonMapper(ObjectMapper objectMapper) { + if (objectMapper == null) { + throw new IllegalArgumentException("ObjectMapper must not be null"); + } + this.objectMapper = objectMapper; + } + + /** + * Returns the underlying Jackson {@link ObjectMapper} used for JSON serialization and + * deserialization. + * @return the ObjectMapper instance + */ + public ObjectMapper getObjectMapper() { + return objectMapper; + } + + @Override + public T readValue(String content, Class type) throws IOException { + return objectMapper.readValue(content, type); + } + + @Override + public T readValue(byte[] content, Class type) throws IOException { + return objectMapper.readValue(content, type); + } + + @Override + public T readValue(String content, TypeRef type) throws IOException { + JavaType javaType = objectMapper.getTypeFactory().constructType(type.getType()); + return objectMapper.readValue(content, javaType); + } + + @Override + public T readValue(byte[] content, TypeRef type) throws IOException { + JavaType javaType = objectMapper.getTypeFactory().constructType(type.getType()); + return objectMapper.readValue(content, javaType); + } + + @Override + public T convertValue(Object fromValue, Class type) { + return objectMapper.convertValue(fromValue, type); + } + + @Override + public T convertValue(Object fromValue, TypeRef type) { + JavaType javaType = objectMapper.getTypeFactory().constructType(type.getType()); + return objectMapper.convertValue(fromValue, javaType); + } + + @Override + public String writeValueAsString(Object value) throws IOException { + return objectMapper.writeValueAsString(value); + } + + @Override + public byte[] writeValueAsBytes(Object value) throws IOException { + return objectMapper.writeValueAsBytes(value); + } + +} diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapperSupplier.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapperSupplier.java new file mode 100644 index 000000000..0e79c3e0e --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapperSupplier.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.jackson; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.McpJsonMapperSupplier; + +/** + * A supplier of {@link McpJsonMapper} instances that uses the Jackson library for JSON + * serialization and deserialization. + *

    + * This implementation provides a {@link McpJsonMapper} backed by a Jackson + * {@link com.fasterxml.jackson.databind.ObjectMapper}. + */ +public class JacksonMcpJsonMapperSupplier implements McpJsonMapperSupplier { + + /** + * Returns a new instance of {@link McpJsonMapper} that uses the Jackson library for + * JSON serialization and deserialization. + *

    + * The returned {@link McpJsonMapper} is backed by a new instance of + * {@link com.fasterxml.jackson.databind.ObjectMapper}. + * @return a new {@link McpJsonMapper} instance + */ + @Override + public McpJsonMapper get() { + return new JacksonMcpJsonMapper(new com.fasterxml.jackson.databind.ObjectMapper()); + } + +} diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/DefaultJsonSchemaValidator.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/DefaultJsonSchemaValidator.java new file mode 100644 index 000000000..1ff28cb80 --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/DefaultJsonSchemaValidator.java @@ -0,0 +1,161 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ +package io.modelcontextprotocol.json.schema.jackson; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.networknt.schema.Schema; +import com.networknt.schema.SchemaRegistry; +import com.networknt.schema.Error; +import com.networknt.schema.dialect.Dialects; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Default implementation of the {@link JsonSchemaValidator} interface. This class + * provides methods to validate structured content against a JSON schema. It uses the + * NetworkNT JSON Schema Validator library for validation. + * + * @author Christian Tzolov + */ +public class DefaultJsonSchemaValidator implements JsonSchemaValidator { + + private static final Logger logger = LoggerFactory.getLogger(DefaultJsonSchemaValidator.class); + + private final ObjectMapper objectMapper; + + private final SchemaRegistry schemaFactory; + + // TODO: Implement a strategy to purge the cache (TTL, size limit, etc.) + private final ConcurrentHashMap schemaCache; + + public DefaultJsonSchemaValidator() { + this(new ObjectMapper()); + } + + public DefaultJsonSchemaValidator(ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + this.schemaFactory = SchemaRegistry.withDialect(Dialects.getDraft202012()); + this.schemaCache = new ConcurrentHashMap<>(); + } + + @Override + public ValidationResponse validate(Map schema, Object structuredContent) { + + if (schema == null) { + throw new IllegalArgumentException("Schema must not be null"); + } + if (structuredContent == null) { + throw new IllegalArgumentException("Structured content must not be null"); + } + + try { + + JsonNode jsonStructuredOutput = (structuredContent instanceof String) + ? this.objectMapper.readTree((String) structuredContent) + : this.objectMapper.valueToTree(structuredContent); + + List validationResult = this.getOrCreateJsonSchema(schema).validate(jsonStructuredOutput); + + // Check if validation passed + if (!validationResult.isEmpty()) { + return ValidationResponse + .asInvalid("Validation failed: structuredContent does not match tool outputSchema. " + + "Validation errors: " + validationResult); + } + + return ValidationResponse.asValid(jsonStructuredOutput.toString()); + + } + catch (JsonProcessingException e) { + logger.error("Failed to validate CallToolResult: Error parsing schema: {}", e); + return ValidationResponse.asInvalid("Error parsing tool JSON Schema: " + e.getMessage()); + } + catch (Exception e) { + logger.error("Failed to validate CallToolResult: Unexpected error: {}", e); + return ValidationResponse.asInvalid("Unexpected validation error: " + e.getMessage()); + } + } + + /** + * Gets a cached Schema or creates and caches a new one. + * @param schema the schema map to convert + * @return the compiled Schema + * @throws JsonProcessingException if schema processing fails + */ + private Schema getOrCreateJsonSchema(Map schema) throws JsonProcessingException { + // Generate cache key based on schema content + String cacheKey = this.generateCacheKey(schema); + + // Try to get from cache first + Schema cachedSchema = this.schemaCache.get(cacheKey); + if (cachedSchema != null) { + return cachedSchema; + } + + // Create new schema if not in cache + Schema newSchema = this.createJsonSchema(schema); + + // Cache the schema + Schema existingSchema = this.schemaCache.putIfAbsent(cacheKey, newSchema); + return existingSchema != null ? existingSchema : newSchema; + } + + /** + * Creates a new Schema from the given schema map. + * @param schema the schema map + * @return the compiled Schema + * @throws JsonProcessingException if schema processing fails + */ + private Schema createJsonSchema(Map schema) throws JsonProcessingException { + // Convert schema map directly to JsonNode (more efficient than string + // serialization) + JsonNode schemaNode = this.objectMapper.valueToTree(schema); + + // Handle case where ObjectMapper might return null (e.g., in mocked scenarios) + if (schemaNode == null) { + throw new JsonProcessingException("Failed to convert schema to JsonNode") { + }; + } + + return this.schemaFactory.getSchema(schemaNode); + } + + /** + * Generates a cache key for the given schema map. + * @param schema the schema map + * @return a cache key string + */ + protected String generateCacheKey(Map schema) { + if (schema.containsKey("$id")) { + // Use the (optional) "$id" field as the cache key if present + return "" + schema.get("$id"); + } + // Fall back to schema's hash code as a simple cache key + // For more sophisticated caching, could use content-based hashing + return String.valueOf(schema.hashCode()); + } + + /** + * Clears the schema cache. Useful for testing or memory management. + */ + public void clearCache() { + this.schemaCache.clear(); + } + + /** + * Returns the current size of the schema cache. + * @return the number of cached schemas + */ + public int getCacheSize() { + return this.schemaCache.size(); + } + +} diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/JacksonJsonSchemaValidatorSupplier.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/JacksonJsonSchemaValidatorSupplier.java new file mode 100644 index 000000000..86153a538 --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/JacksonJsonSchemaValidatorSupplier.java @@ -0,0 +1,29 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema.jackson; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier; + +/** + * A concrete implementation of {@link JsonSchemaValidatorSupplier} that provides a + * {@link JsonSchemaValidator} instance based on the Jackson library. + * + * @see JsonSchemaValidatorSupplier + * @see JsonSchemaValidator + */ +public class JacksonJsonSchemaValidatorSupplier implements JsonSchemaValidatorSupplier { + + /** + * Returns a new instance of {@link JsonSchemaValidator} that uses the Jackson library + * for JSON schema validation. + * @return A {@link JsonSchemaValidator} instance. + */ + @Override + public JsonSchemaValidator get() { + return new DefaultJsonSchemaValidator(); + } + +} diff --git a/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier new file mode 100644 index 000000000..8ea66d698 --- /dev/null +++ b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier @@ -0,0 +1 @@ +io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapperSupplier \ No newline at end of file diff --git a/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier new file mode 100644 index 000000000..0fb0b7e5a --- /dev/null +++ b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier @@ -0,0 +1 @@ +io.modelcontextprotocol.json.schema.jackson.JacksonJsonSchemaValidatorSupplier \ No newline at end of file diff --git a/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/DefaultJsonSchemaValidatorTests.java b/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/DefaultJsonSchemaValidatorTests.java new file mode 100644 index 000000000..7642f0480 --- /dev/null +++ b/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/DefaultJsonSchemaValidatorTests.java @@ -0,0 +1,808 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import io.modelcontextprotocol.json.schema.jackson.DefaultJsonSchemaValidator; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator.ValidationResponse; + +/** + * Tests for {@link DefaultJsonSchemaValidator}. + * + * @author Christian Tzolov + */ +class DefaultJsonSchemaValidatorTests { + + private DefaultJsonSchemaValidator validator; + + private ObjectMapper objectMapper; + + @Mock + private ObjectMapper mockObjectMapper; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + validator = new DefaultJsonSchemaValidator(); + objectMapper = new ObjectMapper(); + } + + /** + * Utility method to convert JSON string to Map + */ + private Map toMap(String json) { + try { + return objectMapper.readValue(json, new TypeReference>() { + }); + } + catch (Exception e) { + throw new RuntimeException("Failed to parse JSON: " + json, e); + } + } + + private List> toListMap(String json) { + try { + return objectMapper.readValue(json, new TypeReference>>() { + }); + } + catch (Exception e) { + throw new RuntimeException("Failed to parse JSON: " + json, e); + } + } + + @Test + void testDefaultConstructor() { + DefaultJsonSchemaValidator defaultValidator = new DefaultJsonSchemaValidator(); + + String schemaJson = """ + { + "type": "object", + "properties": { + "test": {"type": "string"} + } + } + """; + String contentJson = """ + { + "test": "value" + } + """; + + ValidationResponse response = defaultValidator.validate(toMap(schemaJson), toMap(contentJson)); + assertTrue(response.valid()); + } + + @Test + void testConstructorWithObjectMapper() { + ObjectMapper customMapper = new ObjectMapper(); + DefaultJsonSchemaValidator customValidator = new DefaultJsonSchemaValidator(customMapper); + + String schemaJson = """ + { + "type": "object", + "properties": { + "test": {"type": "string"} + } + } + """; + String contentJson = """ + { + "test": "value" + } + """; + + ValidationResponse response = customValidator.validate(toMap(schemaJson), toMap(contentJson)); + assertTrue(response.valid()); + } + + @Test + void testValidateWithValidStringSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + """; + + String contentJson = """ + { + "name": "John Doe", + "age": 30 + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + assertNotNull(response.jsonStructuredOutput()); + } + + @Test + void testValidateWithValidNumberSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "price": {"type": "number", "minimum": 0}, + "quantity": {"type": "integer", "minimum": 1} + }, + "required": ["price", "quantity"] + } + """; + + String contentJson = """ + { + "price": 19.99, + "quantity": 5 + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithValidArraySchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": {"type": "string"} + } + }, + "required": ["items"] + } + """; + + String contentJson = """ + { + "items": ["apple", "banana", "cherry"] + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithValidArraySchemaTopLevelArray() { + String schemaJson = """ + { + "$schema" : "https://json-schema.org/draft/2020-12/schema", + "type" : "array", + "items" : { + "type" : "object", + "properties" : { + "city" : { + "type" : "string" + }, + "summary" : { + "type" : "string" + }, + "temperatureC" : { + "type" : "number", + "format" : "float" + } + }, + "required" : [ "city", "summary", "temperatureC" ] + }, + "additionalProperties" : false + } + """; + + String contentJson = """ + [ + { + "city": "London", + "summary": "Generally mild with frequent rainfall. Winters are cool and damp, summers are warm but rarely hot. Cloudy conditions are common throughout the year.", + "temperatureC": 11.3 + }, + { + "city": "New York", + "summary": "Four distinct seasons with hot and humid summers, cold winters with snow, and mild springs and autumns. Precipitation is fairly evenly distributed throughout the year.", + "temperatureC": 12.8 + }, + { + "city": "San Francisco", + "summary": "Mild year-round with a distinctive Mediterranean climate. Famous for summer fog, mild winters, and little temperature variation throughout the year. Very little rainfall in summer months.", + "temperatureC": 14.6 + }, + { + "city": "Tokyo", + "summary": "Humid subtropical climate with hot, wet summers and mild winters. Experiences a rainy season in early summer and occasional typhoons in late summer to early autumn.", + "temperatureC": 15.4 + } + ] + """; + + Map schema = toMap(schemaJson); + + // Validate as JSON string + ValidationResponse response = validator.validate(schema, contentJson); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + + List> structuredContent = toListMap(contentJson); + + // Validate as List> + response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithInvalidTypeSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + """; + + String contentJson = """ + { + "name": "John Doe", + "age": "thirty" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + assertTrue(response.errorMessage().contains("structuredContent does not match tool outputSchema")); + } + + @Test + void testValidateWithMissingRequiredField() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + """; + + String contentJson = """ + { + "name": "John Doe" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + @Test + void testValidateWithAdditionalPropertiesNotAllowed() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": false + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should not be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + @Test + void testValidateWithAdditionalPropertiesExplicitlyAllowed() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": true + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithDefaultAdditionalProperties() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": true + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithAdditionalPropertiesExplicitlyDisallowed() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": false + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should not be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + @Test + void testValidateWithEmptySchema() { + String schemaJson = """ + { + "additionalProperties": true + } + """; + + String contentJson = """ + { + "anything": "goes" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithEmptyContent() { + String schemaJson = """ + { + "type": "object", + "properties": {} + } + """; + + String contentJson = """ + {} + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithNestedObjectSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + }, + "required": ["name", "address"] + } + }, + "required": ["person"] + } + """; + + String contentJson = """ + { + "person": { + "name": "John Doe", + "address": { + "street": "123 Main St", + "city": "Anytown" + } + } + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithInvalidNestedObjectSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + }, + "required": ["name", "address"] + } + }, + "required": ["person"] + } + """; + + String contentJson = """ + { + "person": { + "name": "John Doe", + "address": { + "street": "123 Main St" + } + } + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + @Test + void testValidateWithJsonProcessingException() throws Exception { + DefaultJsonSchemaValidator validatorWithMockMapper = new DefaultJsonSchemaValidator(mockObjectMapper); + + Map schema = Map.of("type", "object"); + Map structuredContent = Map.of("key", "value"); + + // This will trigger our null check and throw JsonProcessingException + when(mockObjectMapper.valueToTree(any())).thenReturn(null); + + ValidationResponse response = validatorWithMockMapper.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Error parsing tool JSON Schema")); + assertTrue(response.errorMessage().contains("Failed to convert schema to JsonNode")); + } + + @ParameterizedTest + @MethodSource("provideValidSchemaAndContentPairs") + void testValidateWithVariousValidInputs(Map schema, Map content) { + ValidationResponse response = validator.validate(schema, content); + + assertTrue(response.valid(), "Expected validation to pass for schema: " + schema + " and content: " + content); + assertNull(response.errorMessage()); + } + + @ParameterizedTest + @MethodSource("provideInvalidSchemaAndContentPairs") + void testValidateWithVariousInvalidInputs(Map schema, Map content) { + ValidationResponse response = validator.validate(schema, content); + + assertFalse(response.valid(), "Expected validation to fail for schema: " + schema + " and content: " + content); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + private static Map staticToMap(String json) { + try { + ObjectMapper mapper = new ObjectMapper(); + return mapper.readValue(json, new TypeReference>() { + }); + } + catch (Exception e) { + throw new RuntimeException("Failed to parse JSON: " + json, e); + } + } + + private static Stream provideValidSchemaAndContentPairs() { + return Stream.of( + // Boolean schema + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "flag": {"type": "boolean"} + } + } + """), staticToMap(""" + { + "flag": true + } + """)), + // String with additional properties allowed + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "additionalProperties": true + } + """), staticToMap(""" + { + "name": "test", + "extra": "allowed" + } + """)), + // Array with specific items + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "numbers": { + "type": "array", + "items": {"type": "number"} + } + } + } + """), staticToMap(""" + { + "numbers": [1.0, 2.5, 3.14] + } + """)), + // Enum validation + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["active", "inactive", "pending"] + } + } + } + """), staticToMap(""" + { + "status": "active" + } + """))); + } + + private static Stream provideInvalidSchemaAndContentPairs() { + return Stream.of( + // Wrong boolean type + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "flag": {"type": "boolean"} + } + } + """), staticToMap(""" + { + "flag": "true" + } + """)), + // Array with wrong item types + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "numbers": { + "type": "array", + "items": {"type": "number"} + } + } + } + """), staticToMap(""" + { + "numbers": ["one", "two", "three"] + } + """)), + // Invalid enum value + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["active", "inactive", "pending"] + } + } + } + """), staticToMap(""" + { + "status": "unknown" + } + """)), + // Minimum constraint violation + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "age": {"type": "integer", "minimum": 0} + } + } + """), staticToMap(""" + { + "age": -5 + } + """))); + } + + @Test + void testValidationResponseToValid() { + String jsonOutput = "{\"test\":\"value\"}"; + ValidationResponse response = ValidationResponse.asValid(jsonOutput); + assertTrue(response.valid()); + assertNull(response.errorMessage()); + assertEquals(jsonOutput, response.jsonStructuredOutput()); + } + + @Test + void testValidationResponseToInvalid() { + String errorMessage = "Test error message"; + ValidationResponse response = ValidationResponse.asInvalid(errorMessage); + assertFalse(response.valid()); + assertEquals(errorMessage, response.errorMessage()); + assertNull(response.jsonStructuredOutput()); + } + + @Test + void testValidationResponseRecord() { + ValidationResponse response1 = new ValidationResponse(true, null, "{\"valid\":true}"); + ValidationResponse response2 = new ValidationResponse(false, "Error", null); + + assertTrue(response1.valid()); + assertNull(response1.errorMessage()); + assertEquals("{\"valid\":true}", response1.jsonStructuredOutput()); + + assertFalse(response2.valid()); + assertEquals("Error", response2.errorMessage()); + assertNull(response2.jsonStructuredOutput()); + + // Test equality + ValidationResponse response3 = new ValidationResponse(true, null, "{\"valid\":true}"); + assertEquals(response1, response3); + assertNotEquals(response1, response2); + } + +} diff --git a/mcp-json/pom.xml b/mcp-json/pom.xml new file mode 100644 index 000000000..2cbcf3516 --- /dev/null +++ b/mcp-json/pom.xml @@ -0,0 +1,39 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + mcp-parent + 0.18.0-SNAPSHOT + + mcp-json + jar + Java MCP SDK JSON Support + Java MCP SDK JSON Support API + https://github.com/modelcontextprotocol/java-sdk + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + true + + + + + + + + + + diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonInternal.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonInternal.java new file mode 100644 index 000000000..31930ab33 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonInternal.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.util.ServiceLoader; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; + +/** + * Utility class for creating a default {@link McpJsonMapper} instance. This class + * provides a single method to create a default mapper using the {@link ServiceLoader} + * mechanism. + */ +final class McpJsonInternal { + + private static McpJsonMapper defaultJsonMapper = null; + + /** + * Returns the cached default {@link McpJsonMapper} instance. If the default mapper + * has not been created yet, it will be initialized using the + * {@link #createDefaultMapper()} method. + * @return the default {@link McpJsonMapper} instance + * @throws IllegalStateException if no default {@link McpJsonMapper} implementation is + * found + */ + static McpJsonMapper getDefaultMapper() { + if (defaultJsonMapper == null) { + defaultJsonMapper = McpJsonInternal.createDefaultMapper(); + } + return defaultJsonMapper; + } + + /** + * Creates a default {@link McpJsonMapper} instance using the {@link ServiceLoader} + * mechanism. The default mapper is resolved by loading the first available + * {@link McpJsonMapperSupplier} implementation on the classpath. + * @return the default {@link McpJsonMapper} instance + * @throws IllegalStateException if no default {@link McpJsonMapper} implementation is + * found + */ + static McpJsonMapper createDefaultMapper() { + AtomicReference ex = new AtomicReference<>(); + return ServiceLoader.load(McpJsonMapperSupplier.class).stream().flatMap(p -> { + try { + McpJsonMapperSupplier supplier = p.get(); + return Stream.ofNullable(supplier); + } + catch (Exception e) { + addException(ex, e); + return Stream.empty(); + } + }).flatMap(jsonMapperSupplier -> { + try { + return Stream.ofNullable(jsonMapperSupplier.get()); + } + catch (Exception e) { + addException(ex, e); + return Stream.empty(); + } + }).findFirst().orElseThrow(() -> { + if (ex.get() != null) { + return ex.get(); + } + else { + return new IllegalStateException("No default McpJsonMapper implementation found"); + } + }); + } + + private static void addException(AtomicReference ref, Exception toAdd) { + ref.updateAndGet(existing -> { + if (existing == null) { + return new IllegalStateException("Failed to initialize default McpJsonMapper", toAdd); + } + else { + existing.addSuppressed(toAdd); + return existing; + } + }); + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapper.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapper.java new file mode 100644 index 000000000..1e30cad16 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapper.java @@ -0,0 +1,110 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.io.IOException; + +/** + * Abstraction for JSON serialization/deserialization to decouple the SDK from any + * specific JSON library. A default implementation backed by Jackson is provided in + * io.modelcontextprotocol.spec.json.jackson.JacksonJsonMapper. + */ +public interface McpJsonMapper { + + /** + * Deserialize JSON string into a target type. + * @param content JSON as String + * @param type target class + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(String content, Class type) throws IOException; + + /** + * Deserialize JSON bytes into a target type. + * @param content JSON as bytes + * @param type target class + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(byte[] content, Class type) throws IOException; + + /** + * Deserialize JSON string into a parameterized target type. + * @param content JSON as String + * @param type parameterized type reference + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(String content, TypeRef type) throws IOException; + + /** + * Deserialize JSON bytes into a parameterized target type. + * @param content JSON as bytes + * @param type parameterized type reference + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(byte[] content, TypeRef type) throws IOException; + + /** + * Convert a value to a given type, useful for mapping nested JSON structures. + * @param fromValue source value + * @param type target class + * @return converted value + * @param generic type + */ + T convertValue(Object fromValue, Class type); + + /** + * Convert a value to a given parameterized type. + * @param fromValue source value + * @param type target type reference + * @return converted value + * @param generic type + */ + T convertValue(Object fromValue, TypeRef type); + + /** + * Serialize an object to JSON string. + * @param value object to serialize + * @return JSON as String + * @throws IOException on serialization errors + */ + String writeValueAsString(Object value) throws IOException; + + /** + * Serialize an object to JSON bytes. + * @param value object to serialize + * @return JSON as bytes + * @throws IOException on serialization errors + */ + byte[] writeValueAsBytes(Object value) throws IOException; + + /** + * Returns the default {@link McpJsonMapper}. + * @return The default {@link McpJsonMapper} + * @throws IllegalStateException If no {@link McpJsonMapper} implementation exists on + * the classpath. + */ + static McpJsonMapper getDefault() { + return McpJsonInternal.getDefaultMapper(); + } + + /** + * Creates a new default {@link McpJsonMapper}. + * @return The default {@link McpJsonMapper} + * @throws IllegalStateException If no {@link McpJsonMapper} implementation exists on + * the classpath. + */ + static McpJsonMapper createDefault() { + return McpJsonInternal.createDefaultMapper(); + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapperSupplier.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapperSupplier.java new file mode 100644 index 000000000..619f96040 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapperSupplier.java @@ -0,0 +1,14 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.util.function.Supplier; + +/** + * Strategy interface for resolving a {@link McpJsonMapper}. + */ +public interface McpJsonMapperSupplier extends Supplier { + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/TypeRef.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/TypeRef.java new file mode 100644 index 000000000..ab37b43f3 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/TypeRef.java @@ -0,0 +1,44 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; + +/** + * Captures generic type information at runtime for parameterized JSON (de)serialization. + * Usage: TypeRef<List<Foo>> ref = new TypeRef<>(){}; + */ +public abstract class TypeRef { + + private final Type type; + + /** + * Constructs a new TypeRef instance, capturing the generic type information of the + * subclass. This constructor should be called from an anonymous subclass to capture + * the actual type arguments. For example:

    +	 * TypeRef<List<Foo>> ref = new TypeRef<>(){};
    +	 * 
    + * @throws IllegalStateException if TypeRef is not subclassed with actual type + * information + */ + protected TypeRef() { + Type superClass = getClass().getGenericSuperclass(); + if (superClass instanceof Class) { + throw new IllegalStateException("TypeRef constructed without actual type information"); + } + this.type = ((ParameterizedType) superClass).getActualTypeArguments()[0]; + } + + /** + * Returns the captured type information. + * @return the Type representing the actual type argument captured by this TypeRef + * instance + */ + public Type getType() { + return type; + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaInternal.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaInternal.java new file mode 100644 index 000000000..2497e7f80 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaInternal.java @@ -0,0 +1,83 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema; + +import java.util.ServiceLoader; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; + +/** + * Internal utility class for creating a default {@link JsonSchemaValidator} instance. + * This class uses the {@link ServiceLoader} to discover and instantiate a + * {@link JsonSchemaValidatorSupplier} implementation. + */ +final class JsonSchemaInternal { + + private static JsonSchemaValidator defaultValidator = null; + + /** + * Returns the default {@link JsonSchemaValidator} instance. If the default validator + * has not been initialized, it will be created using the {@link ServiceLoader} to + * discover and instantiate a {@link JsonSchemaValidatorSupplier} implementation. + * @return The default {@link JsonSchemaValidator} instance. + * @throws IllegalStateException If no {@link JsonSchemaValidatorSupplier} + * implementation exists on the classpath or if an error occurs during instantiation. + */ + static JsonSchemaValidator getDefaultValidator() { + if (defaultValidator == null) { + defaultValidator = JsonSchemaInternal.createDefaultValidator(); + } + return defaultValidator; + } + + /** + * Creates a default {@link JsonSchemaValidator} instance by loading a + * {@link JsonSchemaValidatorSupplier} implementation using the {@link ServiceLoader}. + * @return A default {@link JsonSchemaValidator} instance. + * @throws IllegalStateException If no {@link JsonSchemaValidatorSupplier} + * implementation is found or if an error occurs during instantiation. + */ + static JsonSchemaValidator createDefaultValidator() { + AtomicReference ex = new AtomicReference<>(); + return ServiceLoader.load(JsonSchemaValidatorSupplier.class).stream().flatMap(p -> { + try { + JsonSchemaValidatorSupplier supplier = p.get(); + return Stream.ofNullable(supplier); + } + catch (Exception e) { + addException(ex, e); + return Stream.empty(); + } + }).flatMap(jsonMapperSupplier -> { + try { + return Stream.of(jsonMapperSupplier.get()); + } + catch (Exception e) { + addException(ex, e); + return Stream.empty(); + } + }).findFirst().orElseThrow(() -> { + if (ex.get() != null) { + return ex.get(); + } + else { + return new IllegalStateException("No default JsonSchemaValidatorSupplier implementation found"); + } + }); + } + + private static void addException(AtomicReference ref, Exception toAdd) { + ref.updateAndGet(existing -> { + if (existing == null) { + return new IllegalStateException("Failed to initialize default JsonSchemaValidatorSupplier", toAdd); + } + else { + existing.addSuppressed(toAdd); + return existing; + } + }); + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidator.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidator.java new file mode 100644 index 000000000..8e35c0237 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidator.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ +package io.modelcontextprotocol.json.schema; + +import java.util.Map; + +/** + * Interface for validating structured content against a JSON schema. This interface + * defines a method to validate structured content based on the provided output schema. + * + * @author Christian Tzolov + */ +public interface JsonSchemaValidator { + + /** + * Represents the result of a validation operation. + * + * @param valid Indicates whether the validation was successful. + * @param errorMessage An error message if the validation failed, otherwise null. + * @param jsonStructuredOutput The text structured content in JSON format if the + * validation was successful, otherwise null. + */ + record ValidationResponse(boolean valid, String errorMessage, String jsonStructuredOutput) { + + public static ValidationResponse asValid(String jsonStructuredOutput) { + return new ValidationResponse(true, null, jsonStructuredOutput); + } + + public static ValidationResponse asInvalid(String message) { + return new ValidationResponse(false, message, null); + } + } + + /** + * Validates the structured content against the provided JSON schema. + * @param schema The JSON schema to validate against. + * @param structuredContent The structured content to validate. + * @return A ValidationResponse indicating whether the validation was successful or + * not. + */ + ValidationResponse validate(Map schema, Object structuredContent); + + /** + * Creates the default {@link JsonSchemaValidator}. + * @return The default {@link JsonSchemaValidator} + * @throws IllegalStateException If no {@link JsonSchemaValidator} implementation + * exists on the classpath. + */ + static JsonSchemaValidator createDefault() { + return JsonSchemaInternal.createDefaultValidator(); + } + + /** + * Returns the default {@link JsonSchemaValidator}. + * @return The default {@link JsonSchemaValidator} + * @throws IllegalStateException If no {@link JsonSchemaValidator} implementation + * exists on the classpath. + */ + static JsonSchemaValidator getDefault() { + return JsonSchemaInternal.getDefaultValidator(); + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorSupplier.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorSupplier.java new file mode 100644 index 000000000..6f69169a0 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorSupplier.java @@ -0,0 +1,19 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema; + +import java.util.function.Supplier; + +/** + * A supplier interface that provides a {@link JsonSchemaValidator} instance. + * Implementations of this interface are expected to return a new or cached instance of + * {@link JsonSchemaValidator} when {@link #get()} is invoked. + * + * @see JsonSchemaValidator + * @see Supplier + */ +public interface JsonSchemaValidatorSupplier extends Supplier { + +} diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 4d9f96e53..f1737a477 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -6,13 +6,13 @@ io.modelcontextprotocol.sdk mcp-parent - 0.8.0-SNAPSHOT + 0.18.0-SNAPSHOT ../../pom.xml mcp-spring-webflux jar - WebFlux implementation of the Java MCP SSE transport - + WebFlux transports + WebFlux implementation for the SSE and Streamable Http Client and Server transports https://github.com/modelcontextprotocol/java-sdk @@ -22,16 +22,22 @@ - + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 0.18.0-SNAPSHOT + + + io.modelcontextprotocol.sdk mcp - 0.8.0-SNAPSHOT + 0.18.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.8.0-SNAPSHOT + 0.18.0-SNAPSHOT test @@ -82,6 +88,12 @@ ${mockito.version} test + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + io.projectreactor reactor-test @@ -93,6 +105,12 @@ ${testcontainers.version} test + + org.testcontainers + toxiproxy + ${toxiproxy.version} + test + org.awaitility @@ -115,6 +133,13 @@ test + + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} + test + + diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java new file mode 100644 index 000000000..0b5ce55cd --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -0,0 +1,616 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.WebClientResponseException; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.ClosedMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportStream; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.McpTransportSession; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +/** + * An implementation of the Streamable HTTP protocol as defined by the + * 2025-03-26 version of the MCP specification. + * + *

    + * The transport is capable of resumability and reconnects. It reacts to transport-level + * session invalidation and will propagate {@link McpTransportSessionNotFoundException + * appropriate exceptions} to the higher level abstraction layer when needed in order to + * allow proper state management. The implementation handles servers that are stateful and + * provide session meta information, but can also communicate with stateless servers that + * do not provide a session identifier and do not support SSE streams. + *

    + *

    + * This implementation does not handle backwards compatibility with the "HTTP + * with SSE" transport. In order to communicate over the phased-out + * 2024-11-05 protocol, use {@link HttpClientSseClientTransport} or + * {@link WebFluxSseClientTransport}. + *

    + * + * @author Dariusz Jędrzejczyk + * @see Streamable + * HTTP transport specification + */ +public class WebClientStreamableHttpTransport implements McpClientTransport { + + private static final String MISSING_SESSION_ID = "[missing_session_id]"; + + private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class); + + private static final String DEFAULT_ENDPOINT = "/mcp"; + + /** + * Event type for JSON-RPC messages received through the SSE connection. The server + * sends messages with this event type to transmit JSON-RPC protocol data. + */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + private static final ParameterizedTypeReference> PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference<>() { + }; + + private final McpJsonMapper jsonMapper; + + private final WebClient webClient; + + private final String endpoint; + + private final boolean openConnectionOnStartup; + + private final boolean resumableStreams; + + private final AtomicReference> activeSession = new AtomicReference<>(); + + private final AtomicReference, Mono>> handler = new AtomicReference<>(); + + private final AtomicReference> exceptionHandler = new AtomicReference<>(); + + private final List supportedProtocolVersions; + + private final String latestSupportedProtocolVersion; + + private WebClientStreamableHttpTransport(McpJsonMapper jsonMapper, WebClient.Builder webClientBuilder, + String endpoint, boolean resumableStreams, boolean openConnectionOnStartup, + List supportedProtocolVersions) { + this.jsonMapper = jsonMapper; + this.webClient = webClientBuilder.build(); + this.endpoint = endpoint; + this.resumableStreams = resumableStreams; + this.openConnectionOnStartup = openConnectionOnStartup; + this.activeSession.set(createTransportSession()); + this.supportedProtocolVersions = List.copyOf(supportedProtocolVersions); + this.latestSupportedProtocolVersion = this.supportedProtocolVersions.stream() + .sorted(Comparator.reverseOrder()) + .findFirst() + .get(); + } + + @Override + public List protocolVersions() { + return supportedProtocolVersions; + } + + /** + * Create a stateful builder for creating {@link WebClientStreamableHttpTransport} + * instances. + * @param webClientBuilder the {@link WebClient.Builder} to use + * @return a builder which will create an instance of + * {@link WebClientStreamableHttpTransport} once {@link Builder#build()} is called + */ + public static Builder builder(WebClient.Builder webClientBuilder) { + return new Builder(webClientBuilder); + } + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler.set(handler); + if (openConnectionOnStartup) { + logger.debug("Eagerly opening connection on startup"); + return this.reconnect(null).then(); + } + return Mono.empty(); + }); + } + + private McpTransportSession createTransportSession() { + Function> onClose = sessionId -> sessionId == null ? Mono.empty() + : webClient.delete() + .uri(this.endpoint) + .header(HttpHeaders.MCP_SESSION_ID, sessionId) + .header(HttpHeaders.PROTOCOL_VERSION, this.latestSupportedProtocolVersion) + .retrieve() + .toBodilessEntity() + .onErrorComplete(e -> { + logger.warn("Got error when closing transport", e); + return true; + }) + .then(); + return new DefaultMcpTransportSession(onClose); + } + + private McpTransportSession createClosedSession(McpTransportSession existingSession) { + var existingSessionId = Optional.ofNullable(existingSession) + .filter(session -> !(session instanceof ClosedMcpTransportSession)) + .flatMap(McpTransportSession::sessionId) + .orElse(null); + return new ClosedMcpTransportSession<>(existingSessionId); + } + + @Override + public void setExceptionHandler(Consumer handler) { + logger.debug("Exception handler registered"); + this.exceptionHandler.set(handler); + } + + private void handleException(Throwable t) { + logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); + if (t instanceof McpTransportSessionNotFoundException) { + McpTransportSession invalidSession = this.activeSession.getAndSet(createTransportSession()); + logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); + invalidSession.close(); + } + Consumer handler = this.exceptionHandler.get(); + if (handler != null) { + handler.accept(t); + } + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + logger.debug("Graceful close triggered"); + McpTransportSession currentSession = this.activeSession.getAndUpdate(this::createClosedSession); + if (currentSession != null) { + return Mono.from(currentSession.closeGracefully()); + } + return Mono.empty(); + }); + } + + private Mono reconnect(McpTransportStream stream) { + return Mono.deferContextual(ctx -> { + if (stream != null) { + logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); + } + else { + logger.debug("Reconnecting with no prior stream"); + } + // Here we attempt to initialize the client. In case the server supports SSE, + // we will establish a long-running + // session here and listen for messages. If it doesn't, that's ok, the server + // is a simple, stateless one. + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + Disposable connection = webClient.get() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM) + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) + .headers(httpHeaders -> { + transportSession.sessionId().ifPresent(id -> httpHeaders.add(HttpHeaders.MCP_SESSION_ID, id)); + if (stream != null) { + stream.lastId().ifPresent(id -> httpHeaders.add(HttpHeaders.LAST_EVENT_ID, id)); + } + }) + .exchangeToFlux(response -> { + if (isEventStream(response)) { + logger.debug("Established SSE stream via GET"); + return eventStream(stream, response); + } + else if (isNotAllowed(response)) { + logger.debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (isNotFound(response)) { + if (transportSession.sessionId().isPresent()) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + return mcpSessionNotFoundError(sessionIdRepresentation); + } + else { + return this.extractError(response, MISSING_SESSION_ID); + } + } + else { + return response.createError().doOnError(e -> { + logger.info("Opening an SSE stream failed. This can be safely ignored.", e); + }).flux(); + } + }) + .flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(ctx) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + return Mono.just(connection); + }); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.create(sink -> { + logger.debug("Sending message {}", message); + // Here we attempt to initialize the client. + // In case the server supports SSE, we will establish a long-running session + // here and + // listen for messages. + // If it doesn't, nothing actually happens here, that's just the way it is... + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + Disposable connection = Flux.deferContextual(ctx -> webClient.post() + .uri(this.endpoint) + .accept(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM) + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) + .headers(httpHeaders -> { + transportSession.sessionId().ifPresent(id -> httpHeaders.add(HttpHeaders.MCP_SESSION_ID, id)); + }) + .bodyValue(message) + .exchangeToFlux(response -> { + if (transportSession + .markInitialized(response.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID))) { + // Once we have a session, we try to open an async stream for + // the server to send notifications and requests out-of-band. + reconnect(null).contextWrite(sink.contextView()).subscribe(); + } + + String sessionRepresentation = sessionIdOrPlaceholder(transportSession); + + // The spec mentions only ACCEPTED, but the existing SDKs can return + // 200 OK for notifications + if (response.statusCode().is2xxSuccessful()) { + Optional contentType = response.headers().contentType(); + long contentLength = response.headers().contentLength().orElse(-1); + // Existing SDKs consume notifications with no response body nor + // content type + if (contentType.isEmpty() || contentLength == 0 + || response.statusCode().equals(HttpStatus.ACCEPTED)) { + logger.trace("Message was successfully sent via POST for session {}", + sessionRepresentation); + // signal the caller that the message was successfully + // delivered + sink.success(); + // communicate to downstream there is no streamed data coming + return Flux.empty(); + } + else { + MediaType mediaType = contentType.get(); + if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + logger.debug("Established SSE stream via POST"); + // communicate to caller that the message was delivered + sink.success(); + // starting a stream + return newEventStream(response, sessionRepresentation); + } + else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { + logger.trace("Received response to POST for session {}", sessionRepresentation); + // communicate to caller the message was delivered + sink.success(); + return directResponseFlux(message, response); + } + else { + logger.warn("Unknown media type {} returned for POST in session {}", contentType, + sessionRepresentation); + return Flux.error(new RuntimeException("Unknown media type returned: " + contentType)); + } + } + } + else { + if (isNotFound(response) && !sessionRepresentation.equals(MISSING_SESSION_ID)) { + return mcpSessionNotFoundError(sessionRepresentation); + } + return this.extractError(response, sessionRepresentation); + } + })) + .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) + .onErrorComplete(t -> { + // handle the error first + this.handleException(t); + // inform the caller of sendMessage + sink.error(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(sink.contextView()) + .subscribe(); + disposableRef.set(connection); + transportSession.addConnection(connection); + }); + } + + private static Flux mcpSessionNotFoundError(String sessionRepresentation) { + logger.warn("Session {} was not found on the MCP server", sessionRepresentation); + // inform the stream/connection subscriber + return Flux.error(new McpTransportSessionNotFoundException(sessionRepresentation)); + } + + private Flux extractError(ClientResponse response, String sessionRepresentation) { + return response.createError().onErrorResume(e -> { + WebClientResponseException responseException = (WebClientResponseException) e; + byte[] body = responseException.getResponseBodyAsByteArray(); + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null; + Exception toPropagate; + try { + McpSchema.JSONRPCResponse jsonRpcResponse = jsonMapper.readValue(body, McpSchema.JSONRPCResponse.class); + jsonRpcError = jsonRpcResponse.error(); + toPropagate = jsonRpcError != null ? new McpError(jsonRpcError) + : new McpTransportException("Can't parse the jsonResponse " + jsonRpcResponse); + } + catch (IOException ex) { + toPropagate = new McpTransportException("Sending request failed, " + e.getMessage(), e); + logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), body); + } + + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { + if (!sessionRepresentation.equals(MISSING_SESSION_ID)) { + return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate)); + } + return Mono.error(new McpTransportException("Received 400 BAD REQUEST for session " + + sessionRepresentation + ". " + toPropagate.getMessage(), toPropagate)); + } + return Mono.error(toPropagate); + }).flux(); + } + + private Flux eventStream(McpTransportStream stream, ClientResponse response) { + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + var idWithMessages = response.bodyToFlux(PARAMETERIZED_TYPE_REF).map(this::parse); + return Flux.from(sessionStream.consumeSseStream(idWithMessages)); + } + + private static boolean isNotFound(ClientResponse response) { + return response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND); + } + + private static boolean isNotAllowed(ClientResponse response) { + return response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED); + } + + private static boolean isEventStream(ClientResponse response) { + return response.statusCode().is2xxSuccessful() && response.headers().contentType().isPresent() + && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM); + } + + private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { + return transportSession.sessionId().orElse(MISSING_SESSION_ID); + } + + private Flux directResponseFlux(McpSchema.JSONRPCMessage sentMessage, + ClientResponse response) { + return response.bodyToMono(String.class).>handle((responseMessage, s) -> { + try { + if (sentMessage instanceof McpSchema.JSONRPCNotification) { + logger.warn("Notification: {} received non-compliant response: {}", sentMessage, + Utils.hasText(responseMessage) ? responseMessage : "[empty]"); + s.complete(); + } + else { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(jsonMapper, + responseMessage); + s.next(List.of(jsonRpcResponse)); + } + } + catch (IOException e) { + s.error(new McpTransportException(e)); + } + }).flatMapIterable(Function.identity()); + } + + private Flux newEventStream(ClientResponse response, String sessionRepresentation) { + McpTransportStream sessionStream = new DefaultMcpTransportStream<>(this.resumableStreams, + this::reconnect); + logger.trace("Sent POST and opened a stream ({}) for session {}", sessionStream.streamId(), + sessionRepresentation); + return eventStream(sessionStream, response); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); + } + + private Tuple2, Iterable> parse(ServerSentEvent event) { + if (MESSAGE_EVENT_TYPE.equals(event.event())) { + try { + // We don't support batching ATM and probably won't since the next version + // considers removing it. + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, event.data()); + return Tuples.of(Optional.ofNullable(event.id()), List.of(message)); + } + catch (IOException ioException) { + throw new McpTransportException("Error parsing JSON-RPC message: " + event.data(), ioException); + } + } + else { + logger.debug("Received SSE event with type: {}", event); + return Tuples.of(Optional.empty(), List.of()); + } + } + + /** + * Builder for {@link WebClientStreamableHttpTransport}. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private WebClient.Builder webClientBuilder; + + private String endpoint = DEFAULT_ENDPOINT; + + private boolean resumableStreams = true; + + private boolean openConnectionOnStartup = false; + + private List supportedProtocolVersions = List.of(ProtocolVersions.MCP_2024_11_05, + ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18); + + private Builder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + } + + /** + * Configure the {@link McpJsonMapper} to use. + * @param jsonMapper instance to use + * @return the builder instance + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Configure the {@link WebClient.Builder} to construct the {@link WebClient}. + * @param webClientBuilder instance to use + * @return the builder instance + */ + public Builder webClientBuilder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + return this; + } + + /** + * Configure the endpoint to make HTTP requests against. + * @param endpoint endpoint to use + * @return the builder instance + */ + public Builder endpoint(String endpoint) { + Assert.hasText(endpoint, "endpoint must be a non-empty String"); + this.endpoint = endpoint; + return this; + } + + /** + * Configure whether to use the stream resumability feature by keeping track of + * SSE event ids. + * @param resumableStreams if {@code true} event ids will be tracked and upon + * disconnection, the last seen id will be used upon reconnection as a header to + * resume consuming messages. + * @return the builder instance + */ + public Builder resumableStreams(boolean resumableStreams) { + this.resumableStreams = resumableStreams; + return this; + } + + /** + * Configure whether the client should open an SSE connection upon startup. Not + * all servers support this (although it is in theory possible with the current + * specification), so use with caution. By default, this value is {@code false}. + * @param openConnectionOnStartup if {@code true} the {@link #connect(Function)} + * method call will try to open an SSE connection before sending any JSON-RPC + * request + * @return the builder instance + */ + public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { + this.openConnectionOnStartup = openConnectionOnStartup; + return this; + } + + /** + * Sets the list of supported protocol versions used in version negotiation. By + * default, the client will send the latest of those versions in the + * {@code MCP-Protocol-Version} header. + *

    + * Setting this value only updates the values used in version negotiation, and + * does NOT impact the actual capabilities of the transport. It should only be + * used for compatibility with servers having strict requirements around the + * {@code MCP-Protocol-Version} header. + * @param supportedProtocolVersions protocol versions supported by this transport + * @return this builder + * @see version + * negotiation specification + * @see Protocol + * Version Header + */ + public Builder supportedProtocolVersions(List supportedProtocolVersions) { + Assert.notEmpty(supportedProtocolVersions, "supportedProtocolVersions must not be empty"); + this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); + return this; + } + + /** + * Construct a fresh instance of {@link WebClientStreamableHttpTransport} using + * the current builder configuration. + * @return a new instance of {@link WebClientStreamableHttpTransport} + */ + public WebClientStreamableHttpTransport build() { + return new WebClientStreamableHttpTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + webClientBuilder, endpoint, resumableStreams, openConnectionOnStartup, supportedProtocolVersions); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 8ea65fd78..91b89d6d2 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -1,18 +1,22 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.client.transport; import java.io.IOException; +import java.util.List; import java.util.function.BiConsumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; -import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; + +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -58,10 +62,12 @@ * "https://spec.modelcontextprotocol.io/specification/basic/transports/#http-with-sse">MCP * HTTP with SSE Transport Specification */ -public class WebFluxSseClientTransport implements ClientMcpTransport { +public class WebFluxSseClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class); + private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2024_11_05; + /** * Event type for JSON-RPC messages received through the SSE connection. The server * sends messages with this event type to transmit JSON-RPC protocol data. @@ -79,7 +85,7 @@ public class WebFluxSseClientTransport implements ClientMcpTransport { * Default SSE endpoint path as specified by the MCP transport specification. This * endpoint is used to establish the SSE connection with the server. */ - private static final String SSE_ENDPOINT = "/sse"; + private static final String DEFAULT_SSE_ENDPOINT = "/sse"; /** * Type reference for parsing SSE events containing string data. @@ -94,10 +100,10 @@ public class WebFluxSseClientTransport implements ClientMcpTransport { private final WebClient webClient; /** - * ObjectMapper for serializing outbound messages and deserializing inbound messages. + * JSON mapper for serializing outbound messages and deserializing inbound messages. * Handles conversion between JSON-RPC messages and their string representation. */ - protected ObjectMapper objectMapper; + protected McpJsonMapper jsonMapper; /** * Subscription for the SSE connection handling inbound messages. Used for cleanup @@ -118,14 +124,21 @@ public class WebFluxSseClientTransport implements ClientMcpTransport { protected final Sinks.One messageEndpointSink = Sinks.one(); /** - * Constructs a new SseClientTransport with the specified WebClient builder. Uses a - * default ObjectMapper instance for JSON processing. + * The SSE endpoint URI provided by the server. Used for sending outbound messages via + * HTTP POST requests. + */ + private String sseEndpoint; + + /** + * Constructs a new SseClientTransport with the specified WebClient builder and + * ObjectMapper. Initializes both inbound and outbound message processing pipelines. * @param webClientBuilder the WebClient.Builder to use for creating the WebClient * instance - * @throws IllegalArgumentException if webClientBuilder is null + * @param jsonMapper the ObjectMapper to use for JSON processing + * @throws IllegalArgumentException if either parameter is null */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder) { - this(webClientBuilder, new ObjectMapper()); + public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper) { + this(webClientBuilder, jsonMapper, DEFAULT_SSE_ENDPOINT); } /** @@ -133,15 +146,23 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder) { * ObjectMapper. Initializes both inbound and outbound message processing pipelines. * @param webClientBuilder the WebClient.Builder to use for creating the WebClient * instance - * @param objectMapper the ObjectMapper to use for JSON processing + * @param jsonMapper the ObjectMapper to use for JSON processing + * @param sseEndpoint the SSE endpoint URI to use for establishing the connection * @throws IllegalArgumentException if either parameter is null */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper, String sseEndpoint) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.webClient = webClientBuilder.build(); + this.sseEndpoint = sseEndpoint; + } + + @Override + public List protocolVersions() { + return List.of(MCP_PROTOCOL_VERSION); } /** @@ -163,11 +184,12 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMappe * @param handler a function that processes incoming JSON-RPC messages and returns * responses * @return a Mono that completes when the connection is fully established - * @throws McpError if there's an error processing SSE events or if an unrecognized - * event type is received */ @Override public Mono connect(Function, Mono> handler) { + // TODO: Avoid eager connection opening and enable resilience + // -> upon disconnects, re-establish connection + // -> allow optimizing for eager connection start using a constructor flag Flux> events = eventStream(); this.inboundSubscription = events.concatMap(event -> Mono.just(event).handle((e, s) -> { if (ENDPOINT_EVENT_TYPE.equals(event.event())) { @@ -178,12 +200,12 @@ public Mono connect(Function, Mono> h else { // TODO: clarify with the spec if multiple events can be // received - s.error(new McpError("Failed to handle SSE endpoint event")); + s.error(new RuntimeException("Failed to handle SSE endpoint event")); } } else if (MESSAGE_EVENT_TYPE.equals(event.event())) { try { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, event.data()); s.next(message); } catch (IOException ioException) { @@ -191,7 +213,8 @@ else if (MESSAGE_EVENT_TYPE.equals(event.event())) { } } else { - s.error(new McpError("Received unrecognized SSE event type: " + event.event())); + logger.debug("Received unrecognized SSE event type: {}", event); + s.complete(); } }).transform(handler)).subscribe(); @@ -220,10 +243,11 @@ public Mono sendMessage(JSONRPCMessage message) { return Mono.empty(); } try { - String jsonText = this.objectMapper.writeValueAsString(message); + String jsonText = this.jsonMapper.writeValueAsString(message); return webClient.post() .uri(messageEndpointUri) .contentType(MediaType.APPLICATION_JSON) + .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) .bodyValue(jsonText) .retrieve() .toBodilessEntity() @@ -254,8 +278,9 @@ public Mono sendMessage(JSONRPCMessage message) { protected Flux> eventStream() {// @formatter:off return this.webClient .get() - .uri(SSE_ENDPOINT) + .uri(this.sseEndpoint) .accept(MediaType.TEXT_EVENT_STREAM) + .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) .retrieve() .bodyToFlux(SSE_TYPE) .retryWhen(Retry.from(retrySignal -> retrySignal.handle(inboundRetryHandler))); @@ -312,13 +337,76 @@ public Mono closeGracefully() { // @formatter:off * type conversion capabilities to handle complex object structures. * @param the target type to convert the data into * @param data the source object to convert - * @param typeRef the TypeReference describing the target type + * @param typeRef the TypeRef describing the target type * @return the unmarshalled object of type T * @throws IllegalArgumentException if the conversion cannot be performed */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); + } + + /** + * Creates a new builder for {@link WebFluxSseClientTransport}. + * @param webClientBuilder the WebClient.Builder to use for creating the WebClient + * instance + * @return a new builder instance + */ + public static Builder builder(WebClient.Builder webClientBuilder) { + return new Builder(webClientBuilder); + } + + /** + * Builder for {@link WebFluxSseClientTransport}. + */ + public static class Builder { + + private final WebClient.Builder webClientBuilder; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private McpJsonMapper jsonMapper; + + /** + * Creates a new builder with the specified WebClient.Builder. + * @param webClientBuilder the WebClient.Builder to use + */ + public Builder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + } + + /** + * Sets the SSE endpoint path. + * @param sseEndpoint the SSE endpoint path + * @return this builder + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the JSON mapper for serialization/deserialization. + * @param jsonMapper the JsonMapper to use + * @return this builder + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Builds a new {@link WebFluxSseClientTransport} instance. + * @return a new transport instance + */ + public WebFluxSseClientTransport build() { + return new WebFluxSseClientTransport(webClientBuilder, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, sseEndpoint); + } + } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java deleted file mode 100644 index bed7293ee..000000000 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java +++ /dev/null @@ -1,410 +0,0 @@ -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; -import java.time.Duration; -import java.util.List; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; - -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.http.codec.ServerSentEvent; -import org.springframework.web.reactive.function.server.RouterFunction; -import org.springframework.web.reactive.function.server.RouterFunctions; -import org.springframework.web.reactive.function.server.ServerRequest; -import org.springframework.web.reactive.function.server.ServerResponse; - -/** - * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using - * Server-Sent Events (SSE). This implementation provides a bidirectional communication - * channel between MCP clients and servers using HTTP POST for client-to-server messages - * and SSE for server-to-client messages. - * - *

    - * Key features: - *

      - *
    • Implements the {@link ServerMcpTransport} interface for MCP server transport - * functionality
    • - *
    • Uses WebFlux for non-blocking request handling and SSE support
    • - *
    • Maintains client sessions for reliable message delivery
    • - *
    • Supports graceful shutdown with session cleanup
    • - *
    • Thread-safe message broadcasting to multiple clients
    • - *
    - * - *

    - * The transport sets up two main endpoints: - *

      - *
    • SSE endpoint (/sse) - For establishing SSE connections with clients
    • - *
    • Message endpoint (configurable) - For receiving JSON-RPC messages from clients
    • - *
    - * - *

    - * This implementation is thread-safe and can handle multiple concurrent client - * connections. It uses {@link ConcurrentHashMap} for session management and Reactor's - * {@link Sinks} for thread-safe message broadcasting. - * - * @author Christian Tzolov - * @author Alexandros Pappas - * @see ServerMcpTransport - * @see ServerSentEvent - */ -public class WebFluxSseServerTransport implements ServerMcpTransport { - - private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransport.class); - - /** - * Event type for JSON-RPC messages sent through the SSE connection. - */ - public static final String MESSAGE_EVENT_TYPE = "message"; - - /** - * Event type for sending the message endpoint URI to clients. - */ - public static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** - * Default SSE endpoint path as specified by the MCP transport specification. - */ - public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - - private final ObjectMapper objectMapper; - - private final String messageEndpoint; - - private final String sseEndpoint; - - private final RouterFunction routerFunction; - - /** - * Map of active client sessions, keyed by session ID. - */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); - - /** - * Flag indicating if the transport is shutting down. - */ - private volatile boolean isClosing = false; - - private Function, Mono> connectHandler; - - /** - * Constructs a new WebFlux SSE server transport instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - */ - public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.notNull(messageEndpoint, "Message endpoint must not be null"); - Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); - - this.objectMapper = objectMapper; - this.messageEndpoint = messageEndpoint; - this.sseEndpoint = sseEndpoint; - this.routerFunction = RouterFunctions.route() - .GET(this.sseEndpoint, this::handleSseConnection) - .POST(this.messageEndpoint, this::handleMessage) - .build(); - } - - /** - * Constructs a new WebFlux SSE server transport instance with the default SSE - * endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - */ - public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - - /** - * Configures the message handler for this transport. In the WebFlux SSE - * implementation, this method stores the handler for processing incoming messages but - * doesn't establish any connections since the server accepts connections rather than - * initiating them. - * @param handler A function that processes incoming JSON-RPC messages and returns - * responses. This handler will be called for each message received through the - * message endpoint. - * @return An empty Mono since the server doesn't initiate connections - */ - @Override - public Mono connect(Function, Mono> handler) { - this.connectHandler = handler; - // Server-side transport doesn't initiate connections - return Mono.empty().then(); - } - - /** - * Broadcasts a JSON-RPC message to all connected clients through their SSE - * connections. The message is serialized to JSON and sent as a server-sent event to - * each active session. - * - *

    - * The method: - *

      - *
    • Serializes the message to JSON
    • - *
    • Creates a server-sent event with the message data
    • - *
    • Attempts to send the event to all active sessions
    • - *
    • Tracks and reports any delivery failures
    • - *
    - * @param message The JSON-RPC message to broadcast - * @return A Mono that completes when the message has been sent to all sessions, or - * errors if any session fails to receive the message - */ - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - if (sessions.isEmpty()) { - logger.debug("No active sessions to broadcast message to"); - return Mono.empty(); - } - - return Mono.create(sink -> { - try {// @formatter:off - String jsonText = objectMapper.writeValueAsString(message); - ServerSentEvent event = ServerSentEvent.builder() - .event(MESSAGE_EVENT_TYPE) - .data(jsonText) - .build(); - - logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - - List failedSessions = sessions.values().stream() - .filter(session -> session.messageSink.tryEmitNext(event).isFailure()) - .map(session -> session.id) - .toList(); - - if (failedSessions.isEmpty()) { - logger.debug("Successfully broadcast message to all sessions"); - sink.success(); - } - else { - String error = "Failed to broadcast message to sessions: " + String.join(", ", failedSessions); - logger.error(error); - sink.error(new RuntimeException(error)); - } // @formatter:on - } - catch (IOException e) { - logger.error("Failed to serialize message: {}", e.getMessage()); - sink.error(e); - } - }); - } - - /** - * Converts data from one type to another using the configured ObjectMapper. This - * method is primarily used for converting between different representations of - * JSON-RPC message data. - * @param The target type to convert to - * @param data The source data to convert - * @param typeRef Type reference describing the target type - * @return The converted data - * @throws IllegalArgumentException if the conversion fails - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - - /** - * Initiates a graceful shutdown of the transport. This method ensures all active - * sessions are properly closed and cleaned up. - * - *

    - * The shutdown process: - *

      - *
    • Marks the transport as closing to prevent new connections
    • - *
    • Closes each active session
    • - *
    • Removes closed sessions from the sessions map
    • - *
    • Times out after 5 seconds if shutdown takes too long
    • - *
    - * @return A Mono that completes when all sessions have been closed - */ - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - isClosing = true; - logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - }).then(Mono.when(sessions.values().stream().map(session -> { - String sessionId = session.id; - return Mono.fromRunnable(() -> session.close()) - .then(Mono.delay(Duration.ofMillis(100))) - .then(Mono.fromRunnable(() -> sessions.remove(sessionId))); - }).toList())) - .timeout(Duration.ofSeconds(5)) - .doOnSuccess(v -> logger.debug("Graceful shutdown completed")) - .doOnError(e -> logger.error("Error during graceful shutdown: {}", e.getMessage())); - } - - /** - * Returns the WebFlux router function that defines the transport's HTTP endpoints. - * This router function should be integrated into the application's web configuration. - * - *

    - * The router function defines two endpoints: - *

      - *
    • GET {sseEndpoint} - For establishing SSE connections
    • - *
    • POST {messageEndpoint} - For receiving client messages
    • - *
    - * @return The configured {@link RouterFunction} for handling HTTP requests - */ - public RouterFunction getRouterFunction() { - return this.routerFunction; - } - - /** - * Handles new SSE connection requests from clients. Creates a new session for each - * connection and sets up the SSE event stream. - * - *

    - * The handler performs the following steps: - *

      - *
    • Generates a unique session ID
    • - *
    • Creates a new ClientSession instance
    • - *
    • Sends the message endpoint URI as an initial event
    • - *
    • Sets up message forwarding for the session
    • - *
    • Handles connection cleanup on completion or errors
    • - *
    - * @param request The incoming server request - * @return A response with the SSE event stream - */ - private Mono handleSseConnection(ServerRequest request) { - if (isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); - } - String sessionId = UUID.randomUUID().toString(); - logger.debug("Creating new SSE connection for session: {}", sessionId); - ClientSession session = new ClientSession(sessionId); - this.sessions.put(sessionId, session); - - return ServerResponse.ok() - .contentType(MediaType.TEXT_EVENT_STREAM) - .body(Flux.>create(sink -> { - // Send initial endpoint event - logger.debug("Sending initial endpoint event to session: {}", sessionId); - sink.next(ServerSentEvent.builder().event(ENDPOINT_EVENT_TYPE).data(messageEndpoint).build()); - - // Subscribe to session messages - session.messageSink.asFlux() - .doOnSubscribe(s -> logger.debug("Session {} subscribed to message sink", sessionId)) - .doOnComplete(() -> { - logger.debug("Session {} completed", sessionId); - sessions.remove(sessionId); - }) - .doOnError(error -> { - logger.error("Error in session {}: {}", sessionId, error.getMessage()); - sessions.remove(sessionId); - }) - .doOnCancel(() -> { - logger.debug("Session {} cancelled", sessionId); - sessions.remove(sessionId); - }) - .subscribe(event -> { - logger.debug("Forwarding event to session {}: {}", sessionId, event); - sink.next(event); - }, sink::error, sink::complete); - - sink.onCancel(() -> { - logger.debug("Session {} cancelled", sessionId); - sessions.remove(sessionId); - }); - }), ServerSentEvent.class); - } - - /** - * Handles incoming JSON-RPC messages from clients. Deserializes the message and - * processes it through the configured message handler. - * - *

    - * The handler: - *

      - *
    • Deserializes the incoming JSON-RPC message
    • - *
    • Passes it through the message handler chain
    • - *
    • Returns appropriate HTTP responses based on processing results
    • - *
    • Handles various error conditions with appropriate error responses
    • - *
    - * @param request The incoming server request containing the JSON-RPC message - * @return A response indicating the message processing result - */ - private Mono handleMessage(ServerRequest request) { - if (isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); - } - - return request.bodyToMono(String.class).flatMap(body -> { - try { - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - return Mono.just(message) - .transform(this.connectHandler) - .flatMap(response -> ServerResponse.ok().build()) - .onErrorResume(error -> { - logger.error("Error processing message: {}", error.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) - .bodyValue(new McpError(error.getMessage())); - }); - } - catch (IllegalArgumentException | IOException e) { - logger.error("Failed to deserialize message: {}", e.getMessage()); - return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); - } - }); - } - - /** - * Represents an active client SSE connection session. Manages the message sink for - * sending events to the client and handles session lifecycle. - * - *

    - * Each session: - *

      - *
    • Has a unique identifier
    • - *
    • Maintains its own message sink for event broadcasting
    • - *
    • Supports clean shutdown through the close method
    • - *
    - */ - private static class ClientSession { - - private final String id; - - private final Sinks.Many> messageSink; - - ClientSession(String id) { - this.id = id; - logger.debug("Creating new session: {}", id); - this.messageSink = Sinks.many().replay().latest(); - logger.debug("Session {} initialized with replay sink", id); - } - - void close() { - logger.debug("Closing session: {}", id); - Sinks.EmitResult result = messageSink.tryEmitComplete(); - if (result.isFailure()) { - logger.warn("Failed to complete message sink for session {}: {}", id, result); - } - else { - logger.debug("Successfully completed message sink for session {}", id); - } - } - - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java new file mode 100644 index 000000000..0c80c5b8b --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -0,0 +1,530 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using + * Server-Sent Events (SSE). This implementation provides a bidirectional communication + * channel between MCP clients and servers using HTTP POST for client-to-server messages + * and SSE for server-to-client messages. + * + *

    + * Key features: + *

      + *
    • Implements the {@link McpServerTransportProvider} interface that allows managing + * {@link McpServerSession} instances and enabling their communication with the + * {@link McpServerTransport} abstraction.
    • + *
    • Uses WebFlux for non-blocking request handling and SSE support
    • + *
    • Maintains client sessions for reliable message delivery
    • + *
    • Supports graceful shutdown with session cleanup
    • + *
    • Thread-safe message broadcasting to multiple clients
    • + *
    + * + *

    + * The transport sets up two main endpoints: + *

      + *
    • SSE endpoint (/sse) - For establishing SSE connections with clients
    • + *
    • Message endpoint (configurable) - For receiving JSON-RPC messages from clients
    • + *
    + * + *

    + * This implementation is thread-safe and can handle multiple concurrent client + * connections. It uses {@link ConcurrentHashMap} for session management and Project + * Reactor's non-blocking APIs for message processing and delivery. + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @author Dariusz Jędrzejczyk + * @see McpServerTransport + * @see ServerSentEvent + */ +public class WebFluxSseServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + private static final String MCP_PROTOCOL_VERSION = "2025-06-18"; + + /** + * Default SSE endpoint path as specified by the MCP transport specification. + */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + public static final String SESSION_ID = "sessionId"; + + public static final String DEFAULT_BASE_URL = ""; + + private final McpJsonMapper jsonMapper; + + /** + * Base URL for the message endpoint. This is used to construct the full URL for + * clients to send their JSON-RPC messages. + */ + private final String baseUrl; + + private final String messageEndpoint; + + private final String sseEndpoint; + + private final RouterFunction routerFunction; + + private McpServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by session ID. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. + */ + private KeepAliveScheduler keepAliveScheduler; + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param jsonMapper The ObjectMapper to use for JSON serialization/deserialization of + * MCP messages. Must not be null. + * @param baseUrl webflux message base path + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @param sseEndpoint The SSE endpoint path. Must not be null. + * @param keepAliveInterval The interval for sending keep-alive pings to clients. + * @param contextExtractor The context extractor to use for extracting MCP transport + * context from HTTP requests. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + private WebFluxSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(jsonMapper, "ObjectMapper must not be null"); + Assert.notNull(baseUrl, "Message base path must not be null"); + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); + + this.jsonMapper = jsonMapper; + this.baseUrl = baseUrl; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.sseEndpoint, this::handleSseConnection) + .POST(this.messageEndpoint, this::handleMessage) + .build(); + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a JSON-RPC message to all connected clients through their SSE + * connections. The message is serialized to JSON and sent as a server-sent event to + * each active session. + * + *

    + * The method: + *

      + *
    • Serializes the message to JSON
    • + *
    • Creates a server-sent event with the message data
    • + *
    • Attempts to send the event to all active sessions
    • + *
    • Tracks and reports any delivery failures
    • + *
    + * @param method The JSON-RPC method to send to clients + * @param params The method parameters to send to clients + * @return A Mono that completes when the message has been sent to all sessions, or + * errors if any session fails to receive the message + */ + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + // FIXME: This javadoc makes claims about using isClosing flag but it's not + // actually + // doing that. + + /** + * Initiates a graceful shutdown of all the sessions. This method ensures all active + * sessions are properly closed and cleaned up. + * @return A Mono that completes when all sessions have been closed + */ + @Override + public Mono closeGracefully() { + return Flux.fromIterable(sessions.values()) + .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) + .flatMap(McpServerSession::closeGracefully) + .then() + .doOnSuccess(v -> { + logger.debug("Graceful shutdown completed"); + sessions.clear(); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

    + * The router function defines two endpoints: + *

      + *
    • GET {sseEndpoint} - For establishing SSE connections
    • + *
    • POST {messageEndpoint} - For receiving client messages
    • + *
    + * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Handles new SSE connection requests from clients. Creates a new session for each + * connection and sets up the SSE event stream. + * @param request The incoming server request + * @return A Mono which emits a response with the SSE event stream + */ + private Mono handleSseConnection(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); + + McpServerSession session = sessionFactory.create(sessionTransport); + String sessionId = session.getId(); + + logger.debug("Created new SSE connection for session: {}", sessionId); + sessions.put(sessionId, session); + + // Send initial endpoint event + logger.debug("Sending initial endpoint event to session: {}", sessionId); + sink.next( + ServerSentEvent.builder().event(ENDPOINT_EVENT_TYPE).data(buildEndpointUrl(sessionId)).build()); + sink.onCancel(() -> { + logger.debug("Session {} cancelled", sessionId); + sessions.remove(sessionId); + }); + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class); + } + + /** + * Constructs the full message endpoint URL by combining the base URL, message path, + * and the required session_id query parameter. + * @param sessionId the unique session identifier + * @return the fully qualified endpoint URL as a string + */ + private String buildEndpointUrl(String sessionId) { + // for WebMVC compatibility + return UriComponentsBuilder.fromUriString(this.baseUrl) + .path(this.messageEndpoint) + .queryParam(SESSION_ID, sessionId) + .build() + .toUriString(); + } + + /** + * Handles incoming JSON-RPC messages from clients. Deserializes the message and + * processes it through the configured message handler. + * + *

    + * The handler: + *

      + *
    • Deserializes the incoming JSON-RPC message
    • + *
    • Passes it through the message handler chain
    • + *
    • Returns appropriate HTTP responses based on processing results
    • + *
    • Handles various error conditions with appropriate error responses
    • + *
    + * @param request The incoming server request containing the JSON-RPC message + * @return A Mono emitting the response indicating the message processing result + */ + private Mono handleMessage(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + if (request.queryParam("sessionId").isEmpty()) { + return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint")); + } + + McpServerSession session = sessions.get(request.queryParam("sessionId").get()); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND) + .bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get())); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); + return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> { + logger.error("Error processing message: {}", error.getMessage()); + // TODO: instead of signalling the error, just respond with 200 OK + // - the error is signalled on the SSE connection + // return ServerResponse.ok().build(); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError(error.getMessage())); + }); + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + private class WebFluxMcpSessionTransport implements McpServerTransport { + + private final FluxSink> sink; + + public WebFluxMcpSessionTransport(FluxSink> sink) { + this.sink = sink; + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromSupplier(() -> { + try { + return jsonMapper.writeValueAsString(message); + } + catch (IOException e) { + throw Exceptions.propagate(e); + } + }).doOnNext(jsonText -> { + ServerSentEvent event = ServerSentEvent.builder() + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); + sink.next(event); + }).doOnError(e -> { + // TODO log with sessionid + Throwable exception = Exceptions.unwrap(e); + sink.error(exception); + }).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(sink::complete); + } + + @Override + public void close() { + sink.complete(); + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebFluxSseServerTransportProvider}. + *

    + * This builder provides a fluent API for configuring and creating instances of + * WebFluxSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String baseUrl = DEFAULT_BASE_URL; + + private String messageEndpoint; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private Duration keepAliveInterval; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + /** + * Sets the McpJsonMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param jsonMapper The McpJsonMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if jsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the project basePath as endpoint prefix where clients should send their + * JSON-RPC messages + * @param baseUrl the message basePath . Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if basePath is null + */ + public Builder basePath(String baseUrl) { + Assert.notNull(baseUrl, "basePath must not be null"); + this.baseUrl = baseUrl; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.messageEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the SSE endpoint path. + * @param sseEndpoint The SSE endpoint path. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if sseEndpoint is null + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the interval for sending keep-alive pings to clients. + * @param keepAliveInterval The keep-alive interval duration. If null, keep-alive + * is disabled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the + * configured settings. + * @return A new WebFluxSseServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebFluxSseServerTransportProvider build() { + Assert.notNull(messageEndpoint, "Message endpoint must be set"); + return new WebFluxSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java new file mode 100644 index 000000000..400be341e --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java @@ -0,0 +1,222 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.util.List; + +/** + * Implementation of a WebFlux based {@link McpStatelessServerTransport}. + * + * @author Dariusz Jędrzejczyk + */ +public class WebFluxStatelessServerTransport implements McpStatelessServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxStatelessServerTransport.class); + + private final McpJsonMapper jsonMapper; + + private final String mcpEndpoint; + + private final RouterFunction routerFunction; + + private McpStatelessServerHandler mcpHandler; + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private WebFluxStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + + this.jsonMapper = jsonMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .build(); + } + + @Override + public void setMcpHandler(McpStatelessServerHandler mcpHandler) { + this.mcpHandler = mcpHandler; + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> this.isClosing = true); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

    + * The router function defines one endpoint handling two HTTP methods: + *

      + *
    • GET {messageEndpoint} - Unsupported, returns 405 METHOD NOT ALLOWED
    • + *
    • POST {messageEndpoint} - For handling client requests and notifications
    • + *
    + * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + private Mono handleGet(ServerRequest request) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + private Mono handlePost(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) + && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { + return ServerResponse.badRequest().build(); + } + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); + + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + return this.mcpHandler.handleRequest(transportContext, jsonrpcRequest).flatMap(jsonrpcResponse -> { + try { + String json = jsonMapper.writeValueAsString(jsonrpcResponse); + return ServerResponse.ok().contentType(MediaType.APPLICATION_JSON).bodyValue(json); + } + catch (IOException e) { + logger.error("Failed to serialize response: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError("Failed to serialize response")); + } + }); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + return this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) + .then(ServerResponse.accepted().build()); + } + else { + return ServerResponse.badRequest() + .bodyValue(new McpError("The server accepts either requests or notifications")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + /** + * Create a builder for the server. + * @return a fresh {@link Builder} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebFluxStatelessServerTransport}. + *

    + * This builder provides a fluent API for configuring and creating instances of + * WebFluxSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private Builder() { + // used by a static method + } + + /** + * Sets the JsonMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param jsonMapper The JsonMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if jsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link WebFluxStatelessServerTransport} with the + * configured settings. + * @return A new WebFluxSseServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebFluxStatelessServerTransport build() { + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + return new WebFluxStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + mcpEndpoint, contextExtractor); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java new file mode 100644 index 000000000..deebfc616 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -0,0 +1,495 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Implementation of a WebFlux based {@link McpStreamableServerTransportProvider}. + * + * @author Dariusz Jędrzejczyk + */ +public class WebFluxStreamableServerTransportProvider implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxStreamableServerTransportProvider.class); + + public static final String MESSAGE_EVENT_TYPE = "message"; + + private final McpJsonMapper jsonMapper; + + private final String mcpEndpoint; + + private final boolean disallowDelete; + + private final RouterFunction routerFunction; + + private McpStreamableServerSession.Factory sessionFactory; + + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private KeepAliveScheduler keepAliveScheduler; + + private WebFluxStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor, boolean disallowDelete, + Duration keepAliveInterval) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + Assert.notNull(mcpEndpoint, "Message endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); + + this.jsonMapper = jsonMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + this.disallowDelete = disallowDelete; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .DELETE(this.mcpEndpoint, this::handleDelete) + .build(); + + if (keepAliveInterval != null) { + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, + ProtocolVersions.MCP_2025_06_18); + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + this.isClosing = true; + return Flux.fromIterable(sessions.values()) + .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) + .flatMap(McpStreamableServerSession::closeGracefully) + .then(); + }).then().doOnSuccess(v -> { + sessions.clear(); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

    + * The router function defines one endpoint with three methods: + *

      + *
    • GET {messageEndpoint} - For the client listening SSE stream
    • + *
    • POST {messageEndpoint} - For receiving client messages
    • + *
    • DELETE {messageEndpoint} - For removing sessions
    • + *
    + * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Opens the listening SSE streams for clients. + * @param request The incoming server request + * @return A Mono which emits a response with the SSE event stream + */ + private Mono handleGet(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + return Mono.defer(() -> { + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) { + return ServerResponse.badRequest().build(); + } + + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().build(); // TODO: say we need a session + // id + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + if (!request.headers().header(HttpHeaders.LAST_EVENT_ID).isEmpty()) { + String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(session.replay(lastId) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), + ServerSentEvent.class); + } + + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport( + sink); + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + .listeningStream(sessionTransport); + sink.onDispose(listeningStream::close); + // TODO Clarify why the outer context is not present in the + // Flux.create sink? + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class); + + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + /** + * Handles incoming JSON-RPC messages from clients. + * @param request The incoming server request containing the JSON-RPC message + * @return A Mono with the response appropriate to a particular Streamable HTTP flow. + */ + private Mono handlePost(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) + && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { + return ServerResponse.badRequest().build(); + } + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + var typeReference = new TypeRef() { + }; + McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(), + typeReference); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); + sessions.put(init.session().getId(), init.session()); + return init.initResult().map(initializeResult -> { + McpSchema.JSONRPCResponse jsonrpcResponse = new McpSchema.JSONRPCResponse( + McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initializeResult, null); + try { + return this.jsonMapper.writeValueAsString(jsonrpcResponse); + } + catch (IOException e) { + logger.warn("Failed to serialize initResponse", e); + throw Exceptions.propagate(e); + } + }) + .flatMap(initResult -> ServerResponse.ok() + .contentType(MediaType.APPLICATION_JSON) + .header(HttpHeaders.MCP_SESSION_ID, init.session().getId()) + .bodyValue(initResult)); + } + + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing")); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND) + .bodyValue(new McpError("Session not found: " + sessionId)); + } + + if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { + return session.accept(jsonrpcResponse).then(ServerResponse.accepted().build()); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + return session.accept(jsonrpcNotification).then(ServerResponse.accepted().build()); + } + else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport(sink); + Mono stream = session.responseStream(jsonrpcRequest, st); + Disposable streamSubscription = stream.onErrorComplete(err -> { + sink.error(err); + return true; + }).contextWrite(sink.contextView()).subscribe(); + sink.onCancel(streamSubscription); + // TODO Clarify why the outer context is not present in the + // Flux.create sink? + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), + ServerSentEvent.class); + } + else { + return ServerResponse.badRequest().bodyValue(new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }) + .switchIfEmpty(ServerResponse.badRequest().build()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + private Mono handleDelete(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + return Mono.defer(() -> { + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().build(); // TODO: say we need a session + // id + } + + if (this.disallowDelete) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + return session.delete().then(ServerResponse.ok().build()); + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + private class WebFluxStreamableMcpSessionTransport implements McpStreamableServerTransport { + + private final FluxSink> sink; + + public WebFluxStreamableMcpSessionTransport(FluxSink> sink) { + this.sink = sink; + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return this.sendMessage(message, null); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { + return Mono.fromSupplier(() -> { + try { + return jsonMapper.writeValueAsString(message); + } + catch (IOException e) { + throw Exceptions.propagate(e); + } + }).doOnNext(jsonText -> { + ServerSentEvent event = ServerSentEvent.builder() + .id(messageId) + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); + sink.next(event); + }).doOnError(e -> { + // TODO log with sessionid + Throwable exception = Exceptions.unwrap(e); + sink.error(exception); + }).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(sink::complete); + } + + @Override + public void close() { + sink.complete(); + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebFluxStreamableServerTransportProvider}. + *

    + * This builder provides a fluent API for configuring and creating instances of + * WebFluxStreamableServerTransportProvider with custom settings. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private boolean disallowDelete; + + private Duration keepAliveInterval; + + private Builder() { + // used by a static method + } + + /** + * Sets the {@link McpJsonMapper} to use for JSON serialization/deserialization of + * MCP messages. + * @param jsonMapper The {@link McpJsonMapper} instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if jsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets whether the session removal capability is disabled. + * @param disallowDelete if {@code true}, the DELETE endpoint will not be + * supported and sessions won't be deleted. + * @return this builder instance + */ + public Builder disallowDelete(boolean disallowDelete) { + this.disallowDelete = disallowDelete; + return this; + } + + /** + * Sets the keep-alive interval for the server transport. + * @param keepAliveInterval The interval for sending keep-alive messages. If null, + * no keep-alive will be scheduled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with + * the configured settings. + * @return A new WebFluxStreamableServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebFluxStreamableServerTransportProvider build() { + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + return new WebFluxStreamableServerTransportProvider( + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, contextExtractor, + disallowDelete, keepAliveInterval); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 4cd24c621..eb8abb90c 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -1,462 +1,106 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol; import java.time.Duration; -import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; +import java.util.stream.Stream; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import reactor.netty.DisposableServer; -import reactor.netty.http.server.HttpServer; -import reactor.test.StepVerifier; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.awaitility.Awaitility.await; - -public class WebFluxSseIntegrationTests { - - private static final int PORT = 8182; - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private DisposableServer httpServer; - - private WebFluxSseServerTransport mcpServerTransport; - - ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); - - @BeforeEach - public void before() { - - this.mcpServerTransport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - - HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransport.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - - clientBulders.put("httpclient", McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT))); - clientBulders.put("webflux", - McpClient.sync(new WebFluxSseClientTransport(WebClient.builder().baseUrl("http://localhost:" + PORT)))); - - } - - @AfterEach - public void after() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @Test - void testCreateMessageWithoutInitialization() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized. Call the initialize method first!"); - }); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithoutSamplingCapabilities(String clientType) { - - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var clientBuilder = clientBulders.get(clientType); - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - }); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageSuccess(String clientType) throws InterruptedException { - - var clientBuilder = clientBulders.get(clientType); - - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsSuccess(String clientType) { - var clientBuilder = clientBulders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpServer.listRoots().roots()).containsAll(roots); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); +@Timeout(15) +class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { - mcpClient.close(); - mcpServer.close(); - } + private static final int PORT = TestUtil.findAvailablePort(); - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithoutCapability(String clientType) { - var clientBuilder = clientBulders.get(clientType); + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; - var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { - }).build(); + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; - // Create client without roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) // No - // roots - // capability - .build(); + private DisposableServer httpServer; - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + private WebFluxSseServerTransportProvider mcpServerTransportProvider; - // Attempt to list roots should fail - assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) - .hasMessage("Roots not supported"); + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); - mcpClient.close(); - mcpServer.close(); + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithEmptyRootsList(String clientType) { - var clientBuilder = clientBulders.get(clientType); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build(); + @Override + protected void prepareClients(int port, String mcpEndpoint) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); - mcpClient.rootsListChangedNotification(); + clientBuilders.put("webflux", + McpClient + .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build()) + .requestTimeout(Duration.ofHours(10))); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - mcpClient.close(); - mcpServer.close(); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithMultipleConsumers(String clientType) { - var clientBuilder = clientBulders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) - .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - - mcpClient.close(); - mcpServer.close(); + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpServerTransportProvider); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsServerCloseWithActiveSubscription(String clientType) { - - var clientBuilder = clientBulders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Close server while subscription is active - mcpServer.close(); - - // Verify client can handle server closure gracefully - mcpClient.close(); + @Override + protected SingleSessionSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpServerTransportProvider); } - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolCallSuccess(String clientType) { - - var clientBuilder = clientBulders.get(clientType); - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); + @BeforeEach + public void before() { - var mcpServer = McpServer.sync(mcpServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) .build(); - var mcpClient = clientBuilder.build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - mcpClient.close(); - mcpServer.close(); + prepareClients(PORT, null); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolListChangeHandlingSuccess(String clientType) { - - var clientBuilder = clientBulders.get(clientType); - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - - mcpClient.close(); - mcpServer.close(); + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } } } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java new file mode 100644 index 000000000..96a786a9e --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java @@ -0,0 +1,91 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.time.Duration; +import java.util.stream.Stream; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; +import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +@Timeout(15) +class WebFluxStatelessIntegrationTests extends AbstractStatelessIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxStatelessServerTransport mcpStreamableServerTransport; + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + clientBuilders + .put("webflux", McpClient + .sync(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()) + .initializationTimeout(Duration.ofHours(10)) + .requestTimeout(Duration.ofHours(10))); + } + + @Override + protected StatelessAsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpStreamableServerTransport); + } + + @Override + protected StatelessSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpStreamableServerTransport); + } + + @BeforeEach + public void before() { + this.mcpStreamableServerTransport = WebFluxStatelessServerTransport.builder() + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .build(); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpStreamableServerTransport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + prepareClients(PORT, null); + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableHttpVersionNegotiationIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableHttpVersionNegotiationIntegrationTests.java new file mode 100644 index 000000000..5d2bfda68 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableHttpVersionNegotiationIntegrationTests.java @@ -0,0 +1,151 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.utils.McpTestRequestRecordingExchangeFilterFunction; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.HttpMethod; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; + +class WebFluxStreamableHttpVersionNegotiationIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private DisposableServer httpServer; + + private final McpTestRequestRecordingExchangeFilterFunction recordingFilterFunction = new McpTestRequestRecordingExchangeFilterFunction(); + + private final McpSchema.Tool toolSpec = McpSchema.Tool.builder() + .name("test-tool") + .description("return the protocol version used") + .build(); + + private final BiFunction toolHandler = ( + exchange, request) -> new McpSchema.CallToolResult( + exchange.transportContext().get("protocol-version").toString(), null); + + private final WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider = WebFluxStreamableServerTransportProvider + .builder() + .contextExtractor(req -> McpTransportContext + .create(Map.of("protocol-version", req.headers().firstHeader("MCP-protocol-version")))) + .build(); + + private final McpSyncServer mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(false).build()) + .tools(new McpServerFeatures.SyncToolSpecification(toolSpec, null, toolHandler)) + .build(); + + @BeforeEach + void setUp() { + RouterFunction filteredRouter = mcpStreamableServerTransportProvider.getRouterFunction() + .filter(recordingFilterFunction); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(filteredRouter); + + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + if (mcpServer != null) { + mcpServer.close(); + } + } + + @Test + void usesLatestVersion() { + var client = McpClient + .sync(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .build()) + .requestTimeout(Duration.ofHours(10)) + .build(); + + client.initialize(); + + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = recordingFilterFunction.getCalls(); + assertThat(calls).filteredOn(c -> !c.body().contains("\"method\":\"initialize\"")) + // GET /mcp ; POST notification/initialized ; POST tools/call + .hasSize(3) + .map(McpTestRequestRecordingExchangeFilterFunction.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_06_18)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_06_18); + mcpServer.close(); + } + + @Test + void usesServerSupportedVersion() { + var transport = WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .supportedProtocolVersions(List.of(ProtocolVersions.MCP_2025_06_18, "2263-03-18")) + .build(); + var client = McpClient.sync(transport).requestTimeout(Duration.ofHours(10)).build(); + + client.initialize(); + + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = recordingFilterFunction.getCalls(); + // Initialize tells the server the Client's latest supported version + // FIXME: Set the correct protocol version on GET /mcp + assertThat(calls) + .filteredOn(c -> !c.body().contains("\"method\":\"initialize\"") && c.method().equals(HttpMethod.POST)) + // POST notification/initialized ; POST tools/call + .hasSize(2) + .map(McpTestRequestRecordingExchangeFilterFunction.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_06_18)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_06_18); + mcpServer.close(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java new file mode 100644 index 000000000..5ab651931 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java @@ -0,0 +1,103 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +@Timeout(15) +class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider; + + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); + clientBuilders.put("webflux", + McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()) + .requestTimeout(Duration.ofHours(10))); + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpStreamableServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpStreamableServerTransportProvider); + } + + @BeforeEach + public void before() { + + this.mcpStreamableServerTransportProvider = WebFluxStreamableServerTransportProvider.builder() + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .build(); + + HttpHandler httpHandler = RouterFunctions + .toHttpHandler(mcpStreamableServerTransportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + prepareClients(PORT, null); + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..191f10376 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java @@ -0,0 +1,20 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; + +@Timeout(15) +public class WebClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java new file mode 100644 index 000000000..cf4458506 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java @@ -0,0 +1,46 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +@Timeout(15) +public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { + + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java new file mode 100644 index 000000000..f47ba5277 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java @@ -0,0 +1,46 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +@Timeout(15) +public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 6cd74631e..72c0168d5 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -6,13 +6,15 @@ import java.time.Duration; -import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import org.springframework.web.reactive.function.client.WebClient; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; /** * Tests for the {@link McpAsyncClient} with {@link WebFluxSseClientTransport}. @@ -24,33 +26,32 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { static String host = "http://localhost:3001"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); + .waitingFor(Wait.forHttp("/").forStatusCode(404).forPort(3001)); @Override - protected ClientMcpTransport createMcpTransport() { - return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); + protected McpClientTransport createMcpTransport() { + return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - public void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } - @Override - protected Duration getTimeoutDuration() { - return Duration.ofMillis(300); + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); } } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 6b980da41..b483029e0 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -7,11 +7,12 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; - import org.springframework.web.reactive.function.client.WebClient; /** @@ -24,33 +25,32 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { static String host = "http://localhost:3001"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { - return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); + protected McpClientTransport createMcpTransport() { + return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - protected void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } - @Override - protected Duration getTimeoutDuration() { - return Duration.ofMillis(300); + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); } } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java new file mode 100644 index 000000000..214fa489b --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java @@ -0,0 +1,403 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; + +import com.sun.net.httpserver.HttpServer; + +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.ProtocolVersions; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Tests for error handling in WebClientStreamableHttpTransport. Addresses concurrency + * issues with proper Reactor patterns. + * + * @author Christian Tzolov + */ +@Timeout(15) +public class WebClientStreamableHttpTransportErrorHandlingTest { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String HOST = "http://localhost:" + PORT; + + private HttpServer server; + + private AtomicReference serverResponseStatus = new AtomicReference<>(200); + + private AtomicReference currentServerSessionId = new AtomicReference<>(null); + + private AtomicReference lastReceivedSessionId = new AtomicReference<>(null); + + private McpClientTransport transport; + + // Initialize latches for proper request synchronization + CountDownLatch firstRequestLatch; + + CountDownLatch secondRequestLatch; + + CountDownLatch getRequestLatch; + + @BeforeEach + void startServer() throws IOException { + + // Initialize latches for proper synchronization + firstRequestLatch = new CountDownLatch(1); + secondRequestLatch = new CountDownLatch(1); + getRequestLatch = new CountDownLatch(1); + + server = HttpServer.create(new InetSocketAddress(PORT), 0); + + // Configure the /mcp endpoint with dynamic response + server.createContext("/mcp", exchange -> { + String method = exchange.getRequestMethod(); + + if ("GET".equals(method)) { + // This is the SSE connection attempt after session establishment + getRequestLatch.countDown(); + // Return 405 Method Not Allowed to indicate SSE not supported + exchange.sendResponseHeaders(405, 0); + exchange.close(); + return; + } + + String requestSessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + lastReceivedSessionId.set(requestSessionId); + + int status = serverResponseStatus.get(); + + // Track which request this is + if (firstRequestLatch.getCount() > 0) { + // // First request - should have no session ID + firstRequestLatch.countDown(); + } + else if (secondRequestLatch.getCount() > 0) { + // Second request - should have session ID + secondRequestLatch.countDown(); + } + + exchange.getResponseHeaders().set("Content-Type", "application/json"); + + // Don't include session ID in 404 and 400 responses - the implementation + // checks if the transport has a session stored locally + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null && status == 200) { + exchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + if (status == 200) { + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes()); + } + else { + exchange.sendResponseHeaders(status, 0); + } + exchange.close(); + }); + + server.setExecutor(null); + server.start(); + + transport = WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(HOST)).build(); + } + + @AfterEach + void stopServer() { + if (server != null) { + server.stop(0); + } + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that 404 response WITHOUT session ID throws McpTransportException (not + * SessionNotFoundException) + */ + @Test + void test404WithoutSessionId() { + serverResponseStatus.set(404); + currentServerSessionId.set(null); // No session ID in response + + var testMessage = createTestMessage(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(throwable -> throwable instanceof McpTransportException + && throwable.getMessage().contains("Not Found") && throwable.getMessage().contains("404") + && !(throwable instanceof McpTransportSessionNotFoundException)) + .verify(Duration.ofSeconds(5)); + } + + /** + * Test that 404 response WITH session ID throws McpTransportSessionNotFoundException + * Fixed version using proper async coordination + */ + @Test + void test404WithSessionId() throws InterruptedException { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("test-session-123"); + + // Set up exception handler to verify session invalidation + @SuppressWarnings("unchecked") + Consumer exceptionHandler = mock(Consumer.class); + transport.setExceptionHandler(exceptionHandler); + + // Connect with handler + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send initial message to establish session + var testMessage = createTestMessage(); + + // Send first message to establish session + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Wait for first request to complete + assertThat(firstRequestLatch.await(5, TimeUnit.SECONDS)).isTrue(); + + // Wait for the GET request (SSE connection attempt) to complete + assertThat(getRequestLatch.await(5, TimeUnit.SECONDS)).isTrue(); + + // Now return 404 for next request + serverResponseStatus.set(404); + + // Use delaySubscription to ensure session is fully processed before next + // request + StepVerifier.create(Mono.delay(Duration.ofMillis(200)).then(transport.sendMessage(testMessage))) + .expectError(McpTransportSessionNotFoundException.class) + .verify(Duration.ofSeconds(5)); + + // Wait for second request to be made + assertThat(secondRequestLatch.await(5, TimeUnit.SECONDS)).isTrue(); + + // Verify the second request included the session ID + assertThat(lastReceivedSessionId.get()).isEqualTo("test-session-123"); + + // Verify exception handler was called with SessionNotFoundException using + // timeout + verify(exceptionHandler, timeout(5000)).accept(any(McpTransportSessionNotFoundException.class)); + } + + /** + * Test that 400 response WITHOUT session ID throws McpTransportException (not + * SessionNotFoundException) + */ + @Test + void test400WithoutSessionId() { + serverResponseStatus.set(400); + currentServerSessionId.set(null); // No session ID + + var testMessage = createTestMessage(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(throwable -> throwable instanceof McpTransportException + && throwable.getMessage().contains("Bad Request") && throwable.getMessage().contains("400") + && !(throwable instanceof McpTransportSessionNotFoundException)) + .verify(Duration.ofSeconds(5)); + } + + /** + * Test that 400 response WITH session ID throws McpTransportSessionNotFoundException + * Fixed version using proper async coordination + */ + @Test + void test400WithSessionId() throws InterruptedException { + + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("test-session-456"); + + // Set up exception handler + @SuppressWarnings("unchecked") + Consumer exceptionHandler = mock(Consumer.class); + transport.setExceptionHandler(exceptionHandler); + + // Connect with handler + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send initial message to establish session + var testMessage = createTestMessage(); + + // Send first message to establish session + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Wait for first request to complete + boolean firstCompleted = firstRequestLatch.await(5, TimeUnit.SECONDS); + assertThat(firstCompleted).isTrue(); + + // Wait for the GET request (SSE connection attempt) to complete + boolean getCompleted = getRequestLatch.await(5, TimeUnit.SECONDS); + assertThat(getCompleted).isTrue(); + + // Now return 400 for next request (simulating unknown session ID) + serverResponseStatus.set(400); + + // Use delaySubscription to ensure session is fully processed before next + // request + StepVerifier.create(Mono.delay(Duration.ofMillis(200)).then(transport.sendMessage(testMessage))) + .expectError(McpTransportSessionNotFoundException.class) + .verify(Duration.ofSeconds(5)); + + // Wait for second request to be made + boolean secondCompleted = secondRequestLatch.await(5, TimeUnit.SECONDS); + assertThat(secondCompleted).isTrue(); + + // Verify the second request included the session ID + assertThat(lastReceivedSessionId.get()).isEqualTo("test-session-456"); + + // Verify exception handler was called with timeout + verify(exceptionHandler, timeout(5000)).accept(any(McpTransportSessionNotFoundException.class)); + } + + /** + * Test session recovery after SessionNotFoundException Fixed version using reactive + * patterns and proper synchronization + */ + @Test + void testSessionRecoveryAfter404() { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("session-1"); + + // Send initial message to establish session + var testMessage = createTestMessage(); + + // Use Mono.defer to ensure proper sequencing + Mono establishSession = transport.sendMessage(testMessage).then(Mono.defer(() -> { + // Simulate session loss - return 404 + serverResponseStatus.set(404); + return transport.sendMessage(testMessage).onErrorResume(McpTransportSessionNotFoundException.class, e -> { + // Expected error, continue with recovery + return Mono.empty(); + }); + })).then(Mono.defer(() -> { + // Now server is back with new session + serverResponseStatus.set(200); + currentServerSessionId.set("session-2"); + lastReceivedSessionId.set(null); // Reset to verify new session + + // Should be able to establish new session + return transport.sendMessage(testMessage); + })).then(Mono.defer(() -> { + // Verify no session ID was sent (since old session was invalidated) + assertThat(lastReceivedSessionId.get()).isNull(); + + // Next request should use the new session ID + return transport.sendMessage(testMessage); + })).doOnSuccess(v -> { + // Session ID should now be sent with requests + assertThat(lastReceivedSessionId.get()).isEqualTo("session-2"); + }); + + StepVerifier.create(establishSession).verifyComplete(); + } + + /** + * Test that reconnect (GET request) also properly handles 404/400 errors Fixed + * version with proper async handling + */ + @Test + void testReconnectErrorHandling() throws InterruptedException { + // Initialize latch for SSE connection + CountDownLatch sseConnectionLatch = new CountDownLatch(1); + + // Set up SSE endpoint for GET requests + server.createContext("/mcp-sse", exchange -> { + String method = exchange.getRequestMethod(); + String requestSessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + if ("GET".equals(method)) { + sseConnectionLatch.countDown(); + int status = serverResponseStatus.get(); + + if (status == 404 && requestSessionId != null) { + // 404 with session ID - should trigger SessionNotFoundException + exchange.sendResponseHeaders(404, 0); + } + else if (status == 404) { + // 404 without session ID - should trigger McpTransportException + exchange.sendResponseHeaders(404, 0); + } + else { + // Normal SSE response + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.sendResponseHeaders(200, 0); + // Send a test SSE event + String sseData = "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{}}\n\n"; + exchange.getResponseBody().write(sseData.getBytes()); + } + } + else { + // POST request handling + exchange.getResponseHeaders().set("Content-Type", "application/json"); + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null) { + exchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes()); + } + exchange.close(); + }); + + // Test with session ID - should get SessionNotFoundException + serverResponseStatus.set(200); + currentServerSessionId.set("sse-session-1"); + + var transport = WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(HOST)) + .endpoint("/mcp-sse") + .openConnectionOnStartup(true) // This will trigger GET request on connect + .build(); + + // First connect successfully + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Wait for SSE connection to be established + boolean connected = sseConnectionLatch.await(5, TimeUnit.SECONDS); + assertThat(connected).isTrue(); + + // Send message to establish session + var testMessage = createTestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Clean up + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + private McpSchema.JSONRPCRequest createTestMessage() { + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Test Client", "1.0.0")); + return new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "test-id", + initializeRequest); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportTest.java new file mode 100644 index 000000000..34e422be4 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ +package io.modelcontextprotocol.client.transport; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import org.springframework.web.reactive.function.client.WebClient; + +class WebClientStreamableHttpTransportTest { + + static String host = "http://localhost:3001"; + + static WebClient.Builder builder; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + builder = WebClient.builder().baseUrl(host); + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + @Test + void testCloseUninitialized() { + var transport = WebClientStreamableHttpTransport.builder(builder).build(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMessage("MCP session has been closed") + .verify(); + } + + @Test + void testCloseInitialized() { + var transport = WebClientStreamableHttpTransport.builder(builder).build(); + + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(err -> err.getMessage().matches("MCP session with ID [a-zA-Z0-9-]* has been closed")) + .verify(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index 912e04f14..a29c9d69c 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -6,13 +6,18 @@ import java.time.Duration; import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -26,6 +31,7 @@ import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.reactive.function.client.WebClient; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -41,7 +47,8 @@ class WebFluxSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); @@ -50,8 +57,6 @@ class WebFluxSseClientTransportTests { private WebClient.Builder webClientBuilder; - private ObjectMapper objectMapper; - // Test class to access protected methods static class TestSseClientTransport extends WebFluxSseClientTransport { @@ -59,17 +64,10 @@ static class TestSseClientTransport extends WebFluxSseClientTransport { private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - public TestSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { - super(webClientBuilder, objectMapper); + public TestSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper) { + super(webClientBuilder, jsonMapper); } - // @Override - // public Mono connect(Function, - // Mono> handler) { - // simulateEndpointEvent("https://localhost:3001"); - // return super.connect(handler); - // } - @Override protected Flux> eventStream() { return super.eventStream().mergeWith(events.asFlux()); @@ -83,6 +81,11 @@ public int getInboundMessageCount() { return inboundMessageCount.get(); } + public void simulateSseComment(String comment) { + events.tryEmitNext(ServerSentEvent.builder().comment(comment).build()); + inboundMessageCount.incrementAndGet(); + } + public void simulateEndpointEvent(String jsonMessage) { events.tryEmitNext(ServerSentEvent.builder().event("endpoint").data(jsonMessage).build()); inboundMessageCount.incrementAndGet(); @@ -95,18 +98,22 @@ public void simulateMessageEvent(String jsonMessage) { } - void startContainer() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } + @AfterAll + static void cleanup() { + container.stop(); + } + @BeforeEach void setUp() { - startContainer(); webClientBuilder = WebClient.builder().baseUrl(host); - objectMapper = new ObjectMapper(); - transport = new TestSseClientTransport(webClientBuilder, objectMapper); + transport = new TestSseClientTransport(webClientBuilder, JSON_MAPPER); transport.connect(Function.identity()).block(); } @@ -115,11 +122,6 @@ void afterEach() { if (transport != null) { assertThatCode(() -> transport.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } - cleanup(); - } - - void cleanup() { - container.stop(); } @Test @@ -129,12 +131,60 @@ void testEndpointEventHandling() { @Test void constructorValidation() { - assertThatThrownBy(() -> new WebFluxSseClientTransport(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> new WebFluxSseClientTransport(null, JSON_MAPPER)) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("WebClient.Builder must not be null"); assertThatThrownBy(() -> new WebFluxSseClientTransport(webClientBuilder, null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("ObjectMapper must not be null"); + .hasMessageContaining("jsonMapper must not be null"); + } + + @Test + void testBuilderPattern() { + // Test default builder + WebFluxSseClientTransport transport1 = WebFluxSseClientTransport.builder(webClientBuilder).build(); + assertThatCode(() -> transport1.closeGracefully().block()).doesNotThrowAnyException(); + + // Test builder with custom ObjectMapper + ObjectMapper customMapper = new ObjectMapper(); + WebFluxSseClientTransport transport2 = WebFluxSseClientTransport.builder(webClientBuilder) + .jsonMapper(new JacksonMcpJsonMapper(customMapper)) + .build(); + assertThatCode(() -> transport2.closeGracefully().block()).doesNotThrowAnyException(); + + // Test builder with custom SSE endpoint + WebFluxSseClientTransport transport3 = WebFluxSseClientTransport.builder(webClientBuilder) + .sseEndpoint("/custom-sse") + .build(); + assertThatCode(() -> transport3.closeGracefully().block()).doesNotThrowAnyException(); + + // Test builder with all custom parameters + WebFluxSseClientTransport transport4 = WebFluxSseClientTransport.builder(webClientBuilder) + .sseEndpoint("/custom-sse") + .build(); + assertThatCode(() -> transport4.closeGracefully().block()).doesNotThrowAnyException(); + } + + @Test + void testCommentSseMessage() { + // If the line starts with a character (:) are comment lins and should be ingored + // https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation + + CopyOnWriteArrayList droppedErrors = new CopyOnWriteArrayList<>(); + reactor.core.publisher.Hooks.onErrorDropped(droppedErrors::add); + + try { + // Simulate receiving the SSE comment line + transport.simulateSseComment("sse comment"); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + assertThat(droppedErrors).hasSize(0); + } + finally { + reactor.core.publisher.Hooks.resetOnErrorDropped(); + } } @Test @@ -240,7 +290,7 @@ void testRetryBehavior() { // Create a WebClient that simulates connection failures WebClient.Builder failingWebClientBuilder = WebClient.builder().baseUrl("http://non-existent-host"); - WebFluxSseClientTransport failingTransport = new WebFluxSseClientTransport(failingWebClientBuilder); + WebFluxSseClientTransport failingTransport = WebFluxSseClientTransport.builder(failingWebClientBuilder).build(); // Verify that the transport attempts to reconnect StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete(); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..3db0bbd3a --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,269 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * async servers using Spring WebFlux infrastructure. + * + *

    + * This test class validates the end-to-end flow of transport context propagation in MCP + * communication for asynchronous client and server implementations. It tests various + * combinations of client types and server transport mechanisms (stateless, streamable, + * SSE) to ensure proper context handling across different configurations. + * + *

    Context Propagation Flow

    + *
      + *
    1. Client sets a value in its transport context via thread-local Reactor context
    2. + *
    3. Client-side context provider extracts the value and adds it as an HTTP header to + * the request
    4. + *
    5. Server-side context extractor reads the header from the incoming request
    6. + *
    7. Server handler receives the extracted context and returns the value as the tool + * call result
    8. + *
    9. Test verifies the round-trip context propagation was successful
    10. + *
    + * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + */ +@Timeout(15) +public class AsyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String HEADER_NAME = "x-test"; + + // Async client context provider + ExchangeFilterFunction asyncClientContextProvider = (request, next) -> Mono.deferContextual(ctx -> { + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // // do stuff with the context + var headerValue = transportContext.get("client-side-header-value"); + if (headerValue == null) { + return next.exchange(request); + } + var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build(); + return next.exchange(reqWithHeader); + }); + + // Tools + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private final BiFunction> asyncStatelessHandler = ( + transportContext, request) -> { + return Mono + .just(new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null)); + }; + + private final BiFunction> asyncStatefulHandler = ( + exchange, request) -> { + return asyncStatelessHandler.apply(exchange.transportContext(), request); + }; + + // Server context extractor + private final McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> { + var headerValue = r.headers().firstHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + // Server transports + private final WebFluxStatelessServerTransport statelessServerTransport = WebFluxStatelessServerTransport.builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxStreamableServerTransportProvider streamableServerTransport = WebFluxStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxSseServerTransportProvider sseServerTransport = WebFluxSseServerTransportProvider.builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + + // Async clients + private final McpAsyncClient asyncStreamableClient = McpClient + .async(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider)) + .build()) + .build(); + + private final McpAsyncClient asyncSseClient = McpClient + .async(WebFluxSseClientTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider)) + .build()) + .build(); + + private DisposableServer httpServer; + + @AfterEach + public void after() { + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (asyncStreamableClient != null) { + asyncStreamableClient.closeGracefully().block(); + } + if (asyncSseClient != null) { + asyncSseClient.closeGracefully().block(); + } + stopHttpServer(); + } + + @Test + void asyncClientStatelessServer() { + + startHttpServer(statelessServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientStreamableServer() { + + startHttpServer(streamableServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientSseServer() { + + startHttpServer(sseServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncSseClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncSseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + private void startHttpServer(RouterFunction routerFunction) { + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + private void stopHttpServer() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..94e16e73e --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,269 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP client and + * server using synchronous operations in a Spring WebFlux environment. + *

    + * This test class validates the end-to-end flow of transport context propagation across + * different WebFlux-based MCP transport implementations + * + *

    + * The test scenario follows these steps: + *

      + *
    1. The client stores a value in a thread-local variable
    2. + *
    3. The client's transport context provider reads this value and includes it in the MCP + * context
    4. + *
    5. A WebClient filter extracts the context value and adds it as an HTTP header + * (x-test)
    6. + *
    7. The server's {@link McpTransportContextExtractor} reads the header from the + * request
    8. + *
    9. The server returns the header value as the tool call result, validating the + * round-trip
    10. + *
    + * + *

    + * This test demonstrates how custom context can be propagated through HTTP headers in a + * reactive WebFlux environment, enabling features like authentication tokens, correlation + * IDs, or other metadata to flow between MCP client and server. + * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + * @since 1.0.0 + * @see McpTransportContext + * @see McpTransportContextExtractor + * @see WebFluxStatelessServerTransport + * @see WebFluxStreamableServerTransportProvider + * @see WebFluxSseServerTransportProvider + */ +@Timeout(15) +public class SyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final BiFunction statelessHandler = ( + transportContext, request) -> { + return new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null); + }; + + private final BiFunction statefulHandler = ( + exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); + + private final McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> { + var headerValue = r.headers().firstHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final WebFluxStatelessServerTransport statelessServerTransport = WebFluxStatelessServerTransport.builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxStreamableServerTransportProvider streamableServerTransport = WebFluxStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxSseServerTransportProvider sseServerTransport = WebFluxSseServerTransportProvider.builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + + private final McpSyncClient streamableClient = McpClient + .sync(WebClientStreamableHttpTransport.builder(WebClient.builder() + .baseUrl("http://localhost:" + PORT) + .filter((request, next) -> Mono.deferContextual(ctx -> { + var context = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // // do stuff with the context + var headerValue = context.get("client-side-header-value"); + if (headerValue == null) { + return next.exchange(request); + } + var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build(); + return next.exchange(reqWithHeader); + }))).build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSyncClient sseClient = McpClient.sync(WebFluxSseClientTransport.builder(WebClient.builder() + .baseUrl("http://localhost:" + PORT) + .filter((request, next) -> Mono.deferContextual(ctx -> { + var context = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // // do stuff with the context + var headerValue = context.get("client-side-header-value"); + if (headerValue == null) { + return next.exchange(request); + } + var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build(); + return next.exchange(reqWithHeader); + }))).build()).transportContextProvider(clientContextProvider).build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private DisposableServer httpServer; + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } + stopHttpServer(); + } + + @Test + void statelessServer() { + + startHttpServer(statelessServerTransport.getRouterFunction()); + + var mcpServer = McpServer.sync(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void streamableServer() { + + startHttpServer(streamableServerTransport.getRouterFunction()); + + var mcpServer = McpServer.sync(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void sseServer() { + startHttpServer(sseServerTransport.getRouterFunction()); + + var mcpServer = McpServer.sync(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + private void startHttpServer(RouterFunction routerFunction) { + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + private void stopHttpServer() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index 1ed0d99b5..fe0314687 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -4,9 +4,8 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -16,28 +15,32 @@ import org.springframework.web.reactive.function.server.RouterFunctions; /** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransport}. + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. * * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - private static final int PORT = 8181; + private static final int PORT = TestUtil.findAvailablePort(); private static final String MESSAGE_ENDPOINT = "/mcp/message"; private DisposableServer httpServer; - @Override - protected ServerMcpTransport createMcpTransport() { - var transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + private McpServerTransportProvider createMcpTransportProvider() { + var transportProvider = new WebFluxSseServerTransportProvider.Builder().messageEndpoint(MESSAGE_ENDPOINT) + .build(); - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + return transportProvider; + } - return transport; + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); } @Override diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index 4db00dd47..67ef90bdf 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -4,9 +4,8 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -16,30 +15,34 @@ import org.springframework.web.reactive.function.server.RouterFunctions; /** - * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransport}. + * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransportProvider}. * * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { - private static final int PORT = 8182; + private static final int PORT = TestUtil.findAvailablePort(); private static final String MESSAGE_ENDPOINT = "/mcp/message"; private DisposableServer httpServer; - private WebFluxSseServerTransport transport; + private WebFluxSseServerTransportProvider transportProvider; @Override - protected ServerMcpTransport createMcpTransport() { - transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - return transport; + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + private McpServerTransportProvider createMcpTransportProvider() { + transportProvider = new WebFluxSseServerTransportProvider.Builder().messageEndpoint(MESSAGE_ENDPOINT).build(); + return transportProvider; } @Override protected void onStart() { - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java new file mode 100644 index 000000000..9b5a80f16 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +/** + * Tests for {@link McpAsyncServer} using + * {@link WebFluxStreamableServerTransportProvider}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxStreamableMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + var transportProvider = WebFluxStreamableServerTransportProvider.builder() + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + return transportProvider; + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java new file mode 100644 index 000000000..6a47ba3ae --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +/** + * Tests for {@link McpAsyncServer} using + * {@link WebFluxStreamableServerTransportProvider}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxStreamableMcpSyncServerTests extends AbstractMcpSyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + var transportProvider = WebFluxStreamableServerTransportProvider.builder() + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + return transportProvider; + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java index 0ab72a99f..dfb004e9b 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java @@ -1,6 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.server.transport; import java.io.IOException; diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpJsonMapperUtils.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpJsonMapperUtils.java new file mode 100644 index 000000000..67347573c --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpJsonMapperUtils.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.utils; + +import io.modelcontextprotocol.json.McpJsonMapper; + +public final class McpJsonMapperUtils { + + private McpJsonMapperUtils() { + } + + public static final McpJsonMapper JSON_MAPPER = McpJsonMapper.createDefault(); + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpTestRequestRecordingExchangeFilterFunction.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpTestRequestRecordingExchangeFilterFunction.java new file mode 100644 index 000000000..55129d481 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpTestRequestRecordingExchangeFilterFunction.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.utils; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpMethod; +import org.springframework.web.reactive.function.server.HandlerFilterFunction; +import org.springframework.web.reactive.function.server.HandlerFunction; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; + +/** + * Simple {@link HandlerFilterFunction} which records calls made to an MCP server. + * + * @author Daniel Garnier-Moiroux + */ +public class McpTestRequestRecordingExchangeFilterFunction implements HandlerFilterFunction { + + private final List calls = new ArrayList<>(); + + @Override + public Mono filter(ServerRequest request, HandlerFunction next) { + Map headers = request.headers() + .asHttpHeaders() + .keySet() + .stream() + .collect(Collectors.toMap(String::toLowerCase, k -> String.join(",", request.headers().header(k)))); + + var cr = request.bodyToMono(String.class).defaultIfEmpty("").map(body -> { + this.calls.add(new Call(request.method(), headers, body)); + return ServerRequest.from(request).body(body).build(); + }); + + return cr.flatMap(next::handle); + + } + + public List getCalls() { + return List.copyOf(calls); + } + + public record Call(HttpMethod method, Map headers, String body) { + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml index 5ad73374a..abc831d13 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml @@ -9,13 +9,13 @@ - + - - + + - + diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 0eebdd2b3..df18b1b8b 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -6,13 +6,13 @@ io.modelcontextprotocol.sdk mcp-parent - 0.8.0-SNAPSHOT + 0.18.0-SNAPSHOT ../../pom.xml mcp-spring-webmvc jar - Spring Web MVC implementation of the Java MCP SSE transport - + Spring Web MVC transports + Web MVC implementation for the SSE and Streamable Http Server transports https://github.com/modelcontextprotocol/java-sdk @@ -22,23 +22,36 @@ - + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 0.18.0-SNAPSHOT + + + io.modelcontextprotocol.sdk mcp - 0.8.0-SNAPSHOT + 0.18.0-SNAPSHOT + + + + org.springframework + spring-webmvc + ${springframework.version} io.modelcontextprotocol.sdk mcp-test - 0.8.0-SNAPSHOT + 0.18.0-SNAPSHOT test - org.springframework - spring-webmvc - ${springframework.version} + io.modelcontextprotocol.sdk + mcp-spring-webflux + 0.18.0-SNAPSHOT + test @@ -77,6 +90,12 @@ ${mockito.version} test + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + org.testcontainers junit-jupiter @@ -122,7 +141,14 @@ test + + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} + test + + - \ No newline at end of file + diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java deleted file mode 100644 index 00928ec7f..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java +++ /dev/null @@ -1,382 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; -import java.time.Duration; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; - -import org.springframework.http.HttpStatus; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.RouterFunctions; -import org.springframework.web.servlet.function.ServerRequest; -import org.springframework.web.servlet.function.ServerResponse; -import org.springframework.web.servlet.function.ServerResponse.SseBuilder; - -/** - * Server-side implementation of the Model Context Protocol (MCP) transport layer using - * HTTP with Server-Sent Events (SSE) through Spring WebMVC. This implementation provides - * a bridge between synchronous WebMVC operations and reactive programming patterns to - * maintain compatibility with the reactive transport interface. - * - *

    - * Key features: - *

      - *
    • Implements bidirectional communication using HTTP POST for client-to-server - * messages and SSE for server-to-client messages
    • - *
    • Manages client sessions with unique IDs for reliable message delivery
    • - *
    • Supports graceful shutdown with proper session cleanup
    • - *
    • Provides JSON-RPC message handling through configured endpoints
    • - *
    • Includes built-in error handling and logging
    • - *
    - * - *

    - * The transport operates on two main endpoints: - *

      - *
    • {@code /sse} - The SSE endpoint where clients establish their event stream - * connection
    • - *
    • A configurable message endpoint where clients send their JSON-RPC messages via HTTP - * POST
    • - *
    - * - *

    - * This implementation uses {@link ConcurrentHashMap} to safely manage multiple client - * sessions in a thread-safe manner. Each client session is assigned a unique ID and - * maintains its own SSE connection. - * - * @author Christian Tzolov - * @author Alexandros Pappas - * @see ServerMcpTransport - * @see RouterFunction - */ -public class WebMvcSseServerTransport implements ServerMcpTransport { - - private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransport.class); - - /** - * Event type for JSON-RPC messages sent through the SSE connection. - */ - public static final String MESSAGE_EVENT_TYPE = "message"; - - /** - * Event type for sending the message endpoint URI to clients. - */ - public static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** - * Default SSE endpoint path as specified by the MCP transport specification. - */ - public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - - private final ObjectMapper objectMapper; - - private final String messageEndpoint; - - private final String sseEndpoint; - - private final RouterFunction routerFunction; - - /** - * Map of active client sessions, keyed by session ID. - */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); - - /** - * Flag indicating if the transport is shutting down. - */ - private volatile boolean isClosing = false; - - /** - * The function to process incoming JSON-RPC messages and produce responses. - */ - private Function, Mono> connectHandler; - - /** - * Constructs a new WebMvcSseServerTransport instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null - */ - public WebMvcSseServerTransport(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.notNull(messageEndpoint, "Message endpoint must not be null"); - Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); - - this.objectMapper = objectMapper; - this.messageEndpoint = messageEndpoint; - this.sseEndpoint = sseEndpoint; - this.routerFunction = RouterFunctions.route() - .GET(this.sseEndpoint, this::handleSseConnection) - .POST(this.messageEndpoint, this::handleMessage) - .build(); - } - - /** - * Constructs a new WebMvcSseServerTransport instance with the default SSE endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null - */ - public WebMvcSseServerTransport(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - - /** - * Sets up the message handler for this transport. In the WebMVC SSE implementation, - * this method only stores the handler for later use, as connections are initiated by - * clients rather than the server. - * @param connectionHandler The function to process incoming JSON-RPC messages and - * produce responses - * @return An empty Mono since the server doesn't initiate connections - */ - @Override - public Mono connect( - Function, Mono> connectionHandler) { - this.connectHandler = connectionHandler; - // Server-side transport doesn't initiate connections - return Mono.empty(); - } - - /** - * Broadcasts a message to all connected clients through their SSE connections. The - * message is serialized to JSON and sent as an SSE event with type "message". If any - * errors occur during sending to a particular client, they are logged but don't - * prevent sending to other clients. - * @param message The JSON-RPC message to broadcast to all connected clients - * @return A Mono that completes when the broadcast attempt is finished - */ - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - return Mono.fromRunnable(() -> { - if (sessions.isEmpty()) { - logger.debug("No active sessions to broadcast message to"); - return; - } - - try { - String jsonText = objectMapper.writeValueAsString(message); - logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - - sessions.values().forEach(session -> { - try { - session.sseBuilder.id(session.id).event(MESSAGE_EVENT_TYPE).data(jsonText); - } - catch (Exception e) { - logger.error("Failed to send message to session {}: {}", session.id, e.getMessage()); - session.sseBuilder.error(e); - } - }); - } - catch (IOException e) { - logger.error("Failed to serialize message: {}", e.getMessage()); - } - }); - } - - /** - * Handles new SSE connection requests from clients by creating a new session and - * establishing an SSE connection. This method: - *

      - *
    • Generates a unique session ID
    • - *
    • Creates a new ClientSession with an SSE builder
    • - *
    • Sends an initial endpoint event to inform the client where to send - * messages
    • - *
    • Maintains the session in the sessions map
    • - *
    - * @param request The incoming server request - * @return A ServerResponse configured for SSE communication, or an error response if - * the server is shutting down or the connection fails - */ - private ServerResponse handleSseConnection(ServerRequest request) { - if (this.isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); - } - - String sessionId = UUID.randomUUID().toString(); - logger.debug("Creating new SSE connection for session: {}", sessionId); - - // Send initial endpoint event - try { - return ServerResponse.sse(sseBuilder -> { - sseBuilder.onComplete(() -> { - logger.debug("SSE connection completed for session: {}", sessionId); - sessions.remove(sessionId); - }); - sseBuilder.onTimeout(() -> { - logger.debug("SSE connection timed out for session: {}", sessionId); - sessions.remove(sessionId); - }); - - ClientSession session = new ClientSession(sessionId, sseBuilder); - this.sessions.put(sessionId, session); - - try { - session.sseBuilder.id(session.id).event(ENDPOINT_EVENT_TYPE).data(messageEndpoint); - } - catch (Exception e) { - logger.error("Failed to poll event from session queue: {}", e.getMessage()); - sseBuilder.error(e); - } - }, Duration.ZERO); - } - catch (Exception e) { - logger.error("Failed to send initial endpoint event to session {}: {}", sessionId, e.getMessage()); - sessions.remove(sessionId); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); - } - } - - /** - * Handles incoming JSON-RPC messages from clients. This method: - *
      - *
    • Deserializes the request body into a JSON-RPC message
    • - *
    • Processes the message through the configured connect handler
    • - *
    • Returns appropriate HTTP responses based on the processing result
    • - *
    - * @param request The incoming server request containing the JSON-RPC message - * @return A ServerResponse indicating success (200 OK) or appropriate error status - * with error details in case of failures - */ - private ServerResponse handleMessage(ServerRequest request) { - if (this.isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); - } - - try { - String body = request.body(String.class); - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - - // Convert the message to a Mono, apply the handler, and block for the - // response - @SuppressWarnings("unused") - McpSchema.JSONRPCMessage response = Mono.just(message).transform(connectHandler).block(); - - return ServerResponse.ok().build(); - } - catch (IllegalArgumentException | IOException e) { - logger.error("Failed to deserialize message: {}", e.getMessage()); - return ServerResponse.badRequest().body(new McpError("Invalid message format")); - } - catch (Exception e) { - logger.error("Error handling message: {}", e.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); - } - } - - /** - * Represents an active client session with its associated SSE connection. Each - * session maintains: - *
      - *
    • A unique session identifier
    • - *
    • An SSE builder for sending server events to the client
    • - *
    • Logging of session lifecycle events
    • - *
    - */ - private static class ClientSession { - - private final String id; - - private final SseBuilder sseBuilder; - - /** - * Creates a new client session with the specified ID and SSE builder. - * @param id The unique identifier for this session - * @param sseBuilder The SSE builder for sending server events to the client - */ - ClientSession(String id, SseBuilder sseBuilder) { - this.id = id; - this.sseBuilder = sseBuilder; - logger.debug("Session {} initialized with SSE emitter", id); - } - - /** - * Closes this session by completing the SSE connection. Any errors during - * completion are logged but do not prevent the session from being marked as - * closed. - */ - void close() { - logger.debug("Closing session: {}", id); - try { - sseBuilder.complete(); - logger.debug("Successfully completed SSE emitter for session {}", id); - } - catch (Exception e) { - logger.warn("Failed to complete SSE emitter for session {}: {}", id, e.getMessage()); - // sseBuilder.error(e); - } - } - - } - - /** - * Converts data from one type to another using the configured ObjectMapper. This is - * particularly useful for handling complex JSON-RPC parameter types. - * @param data The source data object to convert - * @param typeRef The target type reference - * @return The converted object of type T - * @param The target type - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - - /** - * Initiates a graceful shutdown of the transport. This method: - *
      - *
    • Sets the closing flag to prevent new connections
    • - *
    • Closes all active SSE connections
    • - *
    • Removes all session records
    • - *
    - * @return A Mono that completes when all cleanup operations are finished - */ - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - this.isClosing = true; - logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - - sessions.values().forEach(session -> { - String sessionId = session.id; - session.close(); - sessions.remove(sessionId); - }); - - logger.debug("Graceful shutdown completed"); - }); - } - - /** - * Returns the RouterFunction that defines the HTTP endpoints for this transport. The - * router function handles two endpoints: - *
      - *
    • GET /sse - For establishing SSE connections
    • - *
    • POST [messageEndpoint] - For receiving JSON-RPC messages from clients
    • - *
    - * @return The configured RouterFunction for handling HTTP requests - */ - public RouterFunction getRouterFunction() { - return this.routerFunction; - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java new file mode 100644 index 000000000..6c35de56d --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -0,0 +1,568 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.RouterFunctions; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import org.springframework.web.servlet.function.ServerResponse.SseBuilder; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * Server-side implementation of the Model Context Protocol (MCP) transport layer using + * HTTP with Server-Sent Events (SSE) through Spring WebMVC. This implementation provides + * a bridge between synchronous WebMVC operations and reactive programming patterns to + * maintain compatibility with the reactive transport interface. + * + *

    + * Key features: + *

      + *
    • Implements bidirectional communication using HTTP POST for client-to-server + * messages and SSE for server-to-client messages
    • + *
    • Manages client sessions with unique IDs for reliable message delivery
    • + *
    • Supports graceful shutdown with proper session cleanup
    • + *
    • Provides JSON-RPC message handling through configured endpoints
    • + *
    • Includes built-in error handling and logging
    • + *
    + * + *

    + * The transport operates on two main endpoints: + *

      + *
    • {@code /sse} - The SSE endpoint where clients establish their event stream + * connection
    • + *
    • A configurable message endpoint where clients send their JSON-RPC messages via HTTP + * POST
    • + *
    + * + *

    + * This implementation uses {@link ConcurrentHashMap} to safely manage multiple client + * sessions in a thread-safe manner. Each client session is assigned a unique ID and + * maintains its own SSE connection. + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @see McpServerTransportProvider + * @see RouterFunction + */ +public class WebMvcSseServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + public static final String SESSION_ID = "sessionId"; + + /** + * Default SSE endpoint path as specified by the MCP transport specification. + */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + private final McpJsonMapper jsonMapper; + + private final String messageEndpoint; + + private final String sseEndpoint; + + private final String baseUrl; + + private final RouterFunction routerFunction; + + private McpServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by session ID. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + private KeepAliveScheduler keepAliveScheduler; + + /** + * Constructs a new WebMvcSseServerTransportProvider instance. + * @param jsonMapper The McpJsonMapper to use for JSON serialization/deserialization + * of messages. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @param keepAliveInterval The interval for sending keep-alive messages to clients. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @throws IllegalArgumentException if any parameter is null + */ + private WebMvcSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); + Assert.notNull(baseUrl, "Message base URL must not be null"); + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); + + this.jsonMapper = jsonMapper; + this.baseUrl = baseUrl; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.sseEndpoint, this::handleSseConnection) + .POST(this.messageEndpoint, this::handleMessage) + .build(); + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients through their SSE connections. + * The message is serialized to JSON and sent as an SSE event with type "message". If + * any errors occur during sending to a particular client, they are logged but don't + * prevent sending to other clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + /** + * Initiates a graceful shutdown of the transport. This method: + *

      + *
    • Sets the closing flag to prevent new connections
    • + *
    • Closes all active SSE connections
    • + *
    • Removes all session records
    • + *
    + * @return A Mono that completes when all cleanup operations are finished + */ + @Override + public Mono closeGracefully() { + return Flux.fromIterable(sessions.values()).doFirst(() -> { + this.isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + }).flatMap(McpServerSession::closeGracefully).then().doOnSuccess(v -> { + logger.debug("Graceful shutdown completed"); + sessions.clear(); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + /** + * Returns the RouterFunction that defines the HTTP endpoints for this transport. The + * router function handles two endpoints: + *
      + *
    • GET /sse - For establishing SSE connections
    • + *
    • POST [messageEndpoint] - For receiving JSON-RPC messages from clients
    • + *
    + * @return The configured RouterFunction for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Handles new SSE connection requests from clients by creating a new session and + * establishing an SSE connection. This method: + *
      + *
    • Generates a unique session ID
    • + *
    • Creates a new session with a WebMvcMcpSessionTransport
    • + *
    • Sends an initial endpoint event to inform the client where to send + * messages
    • + *
    • Maintains the session in the sessions map
    • + *
    + * @param request The incoming server request + * @return A ServerResponse configured for SSE communication, or an error response if + * the server is shutting down or the connection fails + */ + private ServerResponse handleSseConnection(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + // Send initial endpoint event + return ServerResponse.sse(sseBuilder -> { + WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sseBuilder); + McpServerSession session = sessionFactory.create(sessionTransport); + String sessionId = session.getId(); + logger.debug("Creating new SSE connection for session: {}", sessionId); + sseBuilder.onComplete(() -> { + logger.debug("SSE connection completed for session: {}", sessionId); + sessions.remove(sessionId); + }); + sseBuilder.onTimeout(() -> { + logger.debug("SSE connection timed out for session: {}", sessionId); + sessions.remove(sessionId); + }); + this.sessions.put(sessionId, session); + + try { + sseBuilder.event(ENDPOINT_EVENT_TYPE).data(buildEndpointUrl(sessionId)); + } + catch (Exception e) { + logger.error("Failed to send initial endpoint event: {}", e.getMessage()); + this.sessions.remove(sessionId); + sseBuilder.error(e); + } + }, Duration.ZERO); + } + + /** + * Constructs the full message endpoint URL by combining the base URL, message path, + * and the required session_id query parameter. + * @param sessionId the unique session identifier + * @return the fully qualified endpoint URL as a string + */ + private String buildEndpointUrl(String sessionId) { + // for WebMVC compatibility + return UriComponentsBuilder.fromUriString(this.baseUrl) + .path(this.messageEndpoint) + .queryParam(SESSION_ID, sessionId) + .build() + .toUriString(); + } + + /** + * Handles incoming JSON-RPC messages from clients. This method: + *
      + *
    • Deserializes the request body into a JSON-RPC message
    • + *
    • Processes the message through the session's handle method
    • + *
    • Returns appropriate HTTP responses based on the processing result
    • + *
    + * @param request The incoming server request containing the JSON-RPC message + * @return A ServerResponse indicating success (200 OK) or appropriate error status + * with error details in case of failures + */ + private ServerResponse handleMessage(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + if (request.param(SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); + } + + String sessionId = request.param(SESSION_ID).get(); + McpServerSession session = sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND).body(new McpError("Session not found: " + sessionId)); + } + + try { + final McpTransportContext transportContext = this.contextExtractor.extract(request); + + String body = request.body(String.class); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); + + // Process the message through the session's handle method + session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); // Block + // for + // WebMVC + // compatibility + + return ServerResponse.ok().build(); + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().body(new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Error handling message: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + /** + * Implementation of McpServerTransport for WebMVC SSE sessions. This class handles + * the transport-level communication for a specific client session. + */ + private class WebMvcMcpSessionTransport implements McpServerTransport { + + private final SseBuilder sseBuilder; + + /** + * Lock to ensure thread-safe access to the SSE builder when sending messages. + * This prevents concurrent modifications that could lead to corrupted SSE events. + */ + private final ReentrantLock sseBuilderLock = new ReentrantLock(); + + /** + * Creates a new session transport with the specified SSE builder. + * @param sseBuilder The SSE builder for sending server events to the client + */ + WebMvcMcpSessionTransport(SseBuilder sseBuilder) { + this.sseBuilder = sseBuilder; + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + sseBuilderLock.lock(); + try { + String jsonText = jsonMapper.writeValueAsString(message); + sseBuilder.event(MESSAGE_EVENT_TYPE).data(jsonText); + } + catch (Exception e) { + logger.error("Failed to send message: {}", e.getMessage()); + sseBuilder.error(e); + } + finally { + sseBuilderLock.unlock(); + } + }); + } + + /** + * Converts data from one type to another using the configured McpJsonMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @param The target type + * @return The converted object of type T + */ + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + sseBuilderLock.lock(); + try { + sseBuilder.complete(); + } + catch (Exception e) { + logger.warn("Failed to complete SSE builder: {}", e.getMessage()); + } + finally { + sseBuilderLock.unlock(); + } + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + sseBuilderLock.lock(); + try { + sseBuilder.complete(); + } + catch (Exception e) { + logger.warn("Failed to complete SSE builder: {}", e.getMessage()); + } + finally { + sseBuilderLock.unlock(); + } + } + + } + + /** + * Creates a new Builder instance for configuring and creating instances of + * WebMvcSseServerTransportProvider. + * @return A new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of WebMvcSseServerTransportProvider. + *

    + * This builder provides a fluent API for configuring and creating instances of + * WebMvcSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String baseUrl = ""; + + private String messageEndpoint; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private Duration keepAliveInterval; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + /** + * Sets the JSON object mapper to use for message serialization/deserialization. + * @param jsonMapper The object mapper to use + * @return This builder instance for method chaining + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the base URL for the server transport. + * @param baseUrl The base URL to use + * @return This builder instance for method chaining + */ + public Builder baseUrl(String baseUrl) { + Assert.notNull(baseUrl, "Base URL must not be null"); + this.baseUrl = baseUrl; + return this; + } + + /** + * Sets the endpoint path where clients will send their messages. + * @param messageEndpoint The message endpoint path + * @return This builder instance for method chaining + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); + this.messageEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the endpoint path where clients will establish SSE connections. + *

    + * If not specified, the default value of {@link #DEFAULT_SSE_ENDPOINT} will be + * used. + * @param sseEndpoint The SSE endpoint path + * @return This builder instance for method chaining + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the interval for keep-alive pings. + *

    + * If not specified, keep-alive pings will be disabled. + * @param keepAliveInterval The interval duration for keep-alive pings + * @return This builder instance for method chaining + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of WebMvcSseServerTransportProvider with the configured + * settings. + * @return A new WebMvcSseServerTransportProvider instance + * @throws IllegalStateException if jsonMapper or messageEndpoint is not set + */ + public WebMvcSseServerTransportProvider build() { + if (messageEndpoint == null) { + throw new IllegalStateException("MessageEndpoint must be set"); + } + return new WebMvcSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java new file mode 100644 index 000000000..4223084ff --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java @@ -0,0 +1,240 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.RouterFunctions; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.util.List; + +/** + * Implementation of a WebMVC based {@link McpStatelessServerTransport}. + * + *

    + * This is the non-reactive version of + * {@link io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport} + * + * @author Christian Tzolov + */ +public class WebMvcStatelessServerTransport implements McpStatelessServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebMvcStatelessServerTransport.class); + + private final McpJsonMapper jsonMapper; + + private final String mcpEndpoint; + + private final RouterFunction routerFunction; + + private McpStatelessServerHandler mcpHandler; + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private WebMvcStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + + this.jsonMapper = jsonMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .build(); + } + + @Override + public void setMcpHandler(McpStatelessServerHandler mcpHandler) { + this.mcpHandler = mcpHandler; + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> this.isClosing = true); + } + + /** + * Returns the WebMVC router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

    + * The router function defines one endpoint handling two HTTP methods: + *

      + *
    • GET {messageEndpoint} - Unsupported, returns 405 METHOD NOT ALLOWED
    • + *
    • POST {messageEndpoint} - For handling client requests and notifications
    • + *
    + * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + private ServerResponse handleGet(ServerRequest request) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + private ServerResponse handlePost(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) + && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { + return ServerResponse.badRequest().build(); + } + + try { + String body = request.body(String.class); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); + + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + try { + McpSchema.JSONRPCResponse jsonrpcResponse = this.mcpHandler + .handleRequest(transportContext, jsonrpcRequest) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.ok().contentType(MediaType.APPLICATION_JSON).body(jsonrpcResponse); + } + catch (Exception e) { + logger.error("Failed to handle request: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Failed to handle request: " + e.getMessage())); + } + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + try { + this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.accepted().build(); + } + catch (Exception e) { + logger.error("Failed to handle notification: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Failed to handle notification: " + e.getMessage())); + } + } + else { + return ServerResponse.badRequest() + .body(new McpError("The server accepts either requests or notifications")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().body(new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Unexpected error handling message: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Unexpected error: " + e.getMessage())); + } + } + + /** + * Create a builder for the server. + * @return a fresh {@link Builder} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebMvcStatelessServerTransport}. + *

    + * This builder provides a fluent API for configuring and creating instances of + * WebMvcStatelessServerTransport with custom settings. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private Builder() { + // used by a static method + } + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param jsonMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if jsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "ObjectMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link WebMvcStatelessServerTransport} with the + * configured settings. + * @return A new WebMvcStatelessServerTransport instance + * @throws IllegalStateException if required parameters are not set + */ + public WebMvcStatelessServerTransport build() { + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + return new WebMvcStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + mcpEndpoint, contextExtractor); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java new file mode 100644 index 000000000..f2a58d4d8 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -0,0 +1,690 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; + +import io.modelcontextprotocol.json.McpJsonMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.RouterFunctions; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import org.springframework.web.servlet.function.ServerResponse.SseBuilder; + +import io.modelcontextprotocol.json.TypeRef; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Server-side implementation of the Model Context Protocol (MCP) streamable transport + * layer using HTTP with Server-Sent Events (SSE) through Spring WebMVC. This + * implementation provides a bridge between synchronous WebMVC operations and reactive + * programming patterns to maintain compatibility with the reactive transport interface. + * + *

    + * This is the non-reactive version of + * {@link io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider} + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @see McpStreamableServerTransportProvider + * @see RouterFunction + */ +public class WebMvcStreamableServerTransportProvider implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebMvcStreamableServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default base URL for the message endpoint. + */ + public static final String DEFAULT_BASE_URL = ""; + + /** + * The endpoint URI where clients should send their JSON-RPC messages. Defaults to + * "/mcp". + */ + private final String mcpEndpoint; + + /** + * Flag indicating whether DELETE requests are disallowed on the endpoint. + */ + private final boolean disallowDelete; + + private final McpJsonMapper jsonMapper; + + private final RouterFunction routerFunction; + + private McpStreamableServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by mcp-session-id. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + private KeepAliveScheduler keepAliveScheduler; + + /** + * Constructs a new WebMvcStreamableServerTransportProvider instance. + * @param jsonMapper The McpJsonMapper to use for JSON serialization/deserialization + * of messages. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. + * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. + * @param disallowDelete Whether to disallow DELETE requests on the endpoint. + * @throws IllegalArgumentException if any parameter is null + */ + private WebMvcStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, + boolean disallowDelete, McpTransportContextExtractor contextExtractor, + Duration keepAliveInterval) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + Assert.notNull(contextExtractor, "McpTransportContextExtractor must not be null"); + + this.jsonMapper = jsonMapper; + this.mcpEndpoint = mcpEndpoint; + this.disallowDelete = disallowDelete; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .DELETE(this.mcpEndpoint, this::handleDelete) + .build(); + + if (keepAliveInterval != null) { + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, + ProtocolVersions.MCP_2025_06_18); + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients through their SSE connections. + * If any errors occur during sending to a particular client, they are logged but + * don't prevent sending to other clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Object params) { + if (this.sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size()); + + return Mono.fromRunnable(() -> { + this.sessions.values().parallelStream().forEach(session -> { + try { + session.sendNotification(method, params).block(); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()); + } + }); + }); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when all cleanup operations are finished + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + this.isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()); + + this.sessions.values().parallelStream().forEach(session -> { + try { + session.closeGracefully().block(); + } + catch (Exception e) { + logger.error("Failed to close session {}: {}", session.getId(), e.getMessage()); + } + }); + + this.sessions.clear(); + logger.debug("Graceful shutdown completed"); + }).then().doOnSuccess(v -> { + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + /** + * Returns the RouterFunction that defines the HTTP endpoints for this transport. The + * router function handles three endpoints: + *

      + *
    • GET [mcpEndpoint] - For establishing SSE connections and message replay
    • + *
    • POST [mcpEndpoint] - For receiving JSON-RPC messages from clients
    • + *
    • DELETE [mcpEndpoint] - For session deletion (if enabled)
    • + *
    + * @return The configured RouterFunction for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Setup the listening SSE connections and message replay. + * @param request The incoming server request + * @return A ServerResponse configured for SSE communication, or an error response + */ + private ServerResponse handleGet(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) { + return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().body("Session ID required in mcp-session-id header"); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + logger.debug("Handling GET request for session: {}", sessionId); + + try { + return ServerResponse.sse(sseBuilder -> { + sseBuilder.onTimeout(() -> { + logger.debug("SSE connection timed out for session: {}", sessionId); + }); + + WebMvcStreamableMcpSessionTransport sessionTransport = new WebMvcStreamableMcpSessionTransport( + sessionId, sseBuilder); + + // Check if this is a replay request + if (!request.headers().header(HttpHeaders.LAST_EVENT_ID).isEmpty()) { + String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); + + try { + session.replay(lastId) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .toIterable() + .forEach(message -> { + try { + sessionTransport.sendMessage(message) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to replay message: {}", e.getMessage()); + sseBuilder.error(e); + } + }); + } + catch (Exception e) { + logger.error("Failed to replay messages: {}", e.getMessage()); + sseBuilder.error(e); + } + } + else { + // Establish new listening stream + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + .listeningStream(sessionTransport); + + sseBuilder.onComplete(() -> { + logger.debug("SSE connection completed for session: {}", sessionId); + listeningStream.close(); + }); + } + }, Duration.ZERO); + } + catch (Exception e) { + logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Handles POST requests for incoming JSON-RPC messages from clients. + * @param request The incoming server request containing the JSON-RPC message + * @return A ServerResponse indicating success or appropriate error status + */ + private ServerResponse handlePost(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM) + || !acceptHeaders.contains(MediaType.APPLICATION_JSON)) { + return ServerResponse.badRequest() + .body(new McpError("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON")); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + try { + String body = request.body(String.class); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); + + // Handle initialization request + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(), + new TypeRef() { + }); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); + this.sessions.put(init.session().getId(), init.session()); + + try { + McpSchema.InitializeResult initResult = init.initResult().block(); + + return ServerResponse.ok() + .contentType(MediaType.APPLICATION_JSON) + .header(HttpHeaders.MCP_SESSION_ID, init.session().getId()) + .body(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, + null)); + } + catch (Exception e) { + logger.error("Failed to initialize session: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + // Handle other messages that require a session + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().body(new McpError("Session ID missing")); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND) + .body(new McpError("Session not found: " + sessionId)); + } + + if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { + session.accept(jsonrpcResponse) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.accepted().build(); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + session.accept(jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.accepted().build(); + } + else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + // For streaming responses, we need to return SSE + return ServerResponse.sse(sseBuilder -> { + sseBuilder.onComplete(() -> { + logger.debug("Request response stream completed for session: {}", sessionId); + }); + sseBuilder.onTimeout(() -> { + logger.debug("Request response stream timed out for session: {}", sessionId); + }); + + WebMvcStreamableMcpSessionTransport sessionTransport = new WebMvcStreamableMcpSessionTransport( + sessionId, sseBuilder); + + try { + session.responseStream(jsonrpcRequest, sessionTransport) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to handle request stream: {}", e.getMessage()); + sseBuilder.error(e); + } + }, Duration.ZERO); + } + else { + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().body(new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Error handling message: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + /** + * Handles DELETE requests for session deletion. + * @param request The incoming server request + * @return A ServerResponse indicating success or appropriate error status + */ + private ServerResponse handleDelete(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + if (this.disallowDelete) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().body("Session ID required in mcp-session-id header"); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + try { + session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); + this.sessions.remove(sessionId); + return ServerResponse.ok().build(); + } + catch (Exception e) { + logger.error("Failed to delete session {}: {}", sessionId, e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + /** + * Implementation of McpStreamableServerTransport for WebMVC SSE sessions. This class + * handles the transport-level communication for a specific client session. + * + *

    + * This class is thread-safe and uses a ReentrantLock to synchronize access to the + * underlying SSE builder to prevent race conditions when multiple threads attempt to + * send messages concurrently. + */ + private class WebMvcStreamableMcpSessionTransport implements McpStreamableServerTransport { + + private final String sessionId; + + private final SseBuilder sseBuilder; + + private final ReentrantLock lock = new ReentrantLock(); + + private volatile boolean closed = false; + + /** + * Creates a new session transport with the specified ID and SSE builder. + * @param sessionId The unique identifier for this session + * @param sseBuilder The SSE builder for sending server events to the client + */ + WebMvcStreamableMcpSessionTransport(String sessionId, SseBuilder sseBuilder) { + this.sessionId = sessionId; + this.sseBuilder = sseBuilder; + logger.debug("Streamable session transport {} initialized with SSE builder", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return sendMessage(message, null); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection with a + * specific message ID. + * @param message The JSON-RPC message to send + * @param messageId The message ID for SSE event identification + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { + return Mono.fromRunnable(() -> { + if (this.closed) { + logger.debug("Attempted to send message to closed session: {}", this.sessionId); + return; + } + + this.lock.lock(); + try { + if (this.closed) { + logger.debug("Session {} was closed during message send attempt", this.sessionId); + return; + } + + String jsonText = jsonMapper.writeValueAsString(message); + this.sseBuilder.id(messageId != null ? messageId : this.sessionId) + .event(MESSAGE_EVENT_TYPE) + .data(jsonText); + logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); + try { + this.sseBuilder.error(e); + } + catch (Exception errorException) { + logger.error("Failed to send error to SSE builder for session {}: {}", this.sessionId, + errorException.getMessage()); + } + } + finally { + this.lock.unlock(); + } + }); + } + + /** + * Converts data from one type to another using the configured McpJsonMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + WebMvcStreamableMcpSessionTransport.this.close(); + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + this.lock.lock(); + try { + if (this.closed) { + logger.debug("Session transport {} already closed", this.sessionId); + return; + } + + this.closed = true; + + this.sseBuilder.complete(); + logger.debug("Successfully completed SSE builder for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + } + finally { + this.lock.unlock(); + } + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebMvcStreamableServerTransportProvider}. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String mcpEndpoint = "/mcp"; + + private boolean disallowDelete = false; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private Duration keepAliveInterval; + + /** + * Sets the McpJsonMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param jsonMapper The McpJsonMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if jsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param mcpEndpoint The MCP endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if mcpEndpoint is null + */ + public Builder mcpEndpoint(String mcpEndpoint) { + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + this.mcpEndpoint = mcpEndpoint; + return this; + } + + /** + * Sets whether to disallow DELETE requests on the endpoint. + * @param disallowDelete true to disallow DELETE requests, false otherwise + * @return this builder instance + */ + public Builder disallowDelete(boolean disallowDelete) { + this.disallowDelete = disallowDelete; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets the keep-alive interval for the transport. If set, a keep-alive scheduler + * will be created to periodically check and send keep-alive messages to clients. + * @param keepAliveInterval The interval duration for keep-alive messages, or null + * to disable keep-alive + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of {@link WebMvcStreamableServerTransportProvider} with + * the configured settings. + * @return A new WebMvcStreamableServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebMvcStreamableServerTransportProvider build() { + Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); + return new WebMvcStreamableServerTransportProvider( + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, disallowDelete, + contextExtractor, keepAliveInterval); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java new file mode 100644 index 000000000..cc9945436 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java @@ -0,0 +1,297 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpStatelessSyncServer; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil.TomcatServer; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * servers using Spring WebMVC transport implementations. + * + *

    + * This test class validates the end-to-end flow of transport context propagation across + * different MCP transport mechanisms in a Spring WebMVC environment. It demonstrates how + * contextual information can be passed from client to server through HTTP headers and + * properly extracted and utilized on the server side. + * + *

    Transport Types Tested

    + *
      + *
    • Stateless: Tests context propagation with + * {@link WebMvcStatelessServerTransport} where each request is independent
    • + *
    • Streamable HTTP: Tests context propagation with + * {@link WebMvcStreamableServerTransportProvider} supporting stateful server + * sessions
    • + *
    • Server-Sent Events (SSE): Tests context propagation with + * {@link WebMvcSseServerTransportProvider} for long-lived connections
    • + *
    + * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + */ +@Timeout(15) +public class McpTransportContextIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private TomcatServer tomcatServer; + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final McpSyncHttpClientRequestCustomizer clientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + }; + + private static final BiFunction statelessHandler = ( + transportContext, + request) -> new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null); + + private static final BiFunction statefulHandler = ( + exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); + + private static McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> { + String headerValue = r.servletRequest().getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final McpSyncClient streamableClient = McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSyncClient sseClient = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private static final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } + stopTomcat(); + } + + @Test + void statelessServer() { + startTomcat(TestStatelessConfig.class); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + } + + @Test + void streamableServer() { + + startTomcat(TestStreamableHttpConfig.class); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + } + + @Test + void sseServer() { + startTomcat(TestSseConfig.class); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + } + + private void startTomcat(Class componentClass) { + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, componentClass); + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcatServer != null && tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Configuration + @EnableWebMvc + static class TestStatelessConfig { + + @Bean + public WebMvcStatelessServerTransport webMvcStatelessServerTransport() { + + return WebMvcStatelessServerTransport.builder().contextExtractor(serverContextExtractor).build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcStatelessServerTransport transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpStatelessSyncServer mcpStatelessServer(WebMvcStatelessServerTransport transportProvider) { + return McpServer.sync(transportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + } + + } + + @Configuration + @EnableWebMvc + static class TestStreamableHttpConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransport() { + + return WebMvcStreamableServerTransportProvider.builder().contextExtractor(serverContextExtractor).build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpSyncServer mcpStreamableServer(WebMvcStreamableServerTransportProvider transportProvider) { + return McpServer.sync(transportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + } + + } + + @Configuration + @EnableWebMvc + static class TestSseConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransport() { + + return WebMvcSseServerTransportProvider.builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpSyncServer mcpSseServer(WebMvcSseServerTransportProvider transportProvider) { + return McpServer.sync(transportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java new file mode 100644 index 000000000..8625b6a70 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java @@ -0,0 +1,64 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.server; + +import org.apache.catalina.Context; +import org.apache.catalina.startup.Tomcat; + +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; + +/** + * @author Christian Tzolov + */ +public class TomcatTestUtil { + + TomcatTestUtil() { + // Prevent instantiation + } + + public record TomcatServer(Tomcat tomcat, AnnotationConfigWebApplicationContext appContext) { + } + + public static TomcatServer createTomcatServer(String contextPath, int port, Class componentClass) { + + // Set up Tomcat first + var tomcat = new Tomcat(); + tomcat.setPort(port); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext(contextPath, baseDir); + + // Create and configure Spring WebMvc context + var appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(componentClass); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + // Configure and start the connector with async support + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return new TomcatServer(tomcat, appContext); + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java new file mode 100644 index 000000000..36aaa27fb --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java @@ -0,0 +1,117 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import reactor.netty.DisposableServer; + +/** + * Tests for {@link McpAsyncServer} using {@link WebMvcSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebMcpStreamableAsyncServerTransportTests extends AbstractMcpAsyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MCP_ENDPOINT = "/mcp"; + + private DisposableServer httpServer; + + private AnnotationConfigWebApplicationContext appContext; + + private Tomcat tomcat; + + private McpStreamableServerTransportProvider transportProvider; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() { + return WebMvcStreamableServerTransportProvider.builder().mcpEndpoint(MCP_ENDPOINT).build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transportProvider; + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java new file mode 100644 index 000000000..2f75551eb --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java @@ -0,0 +1,116 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import reactor.netty.DisposableServer; + +/** + * Tests for {@link McpAsyncServer} using {@link WebMvcSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebMcpStreamableSyncServerTransportTests extends AbstractMcpSyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MCP_ENDPOINT = "/mcp"; + + private DisposableServer httpServer; + + private AnnotationConfigWebApplicationContext appContext; + + private Tomcat tomcat; + + private McpStreamableServerTransportProvider transportProvider; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() { + return WebMvcStreamableServerTransportProvider.builder().mcpEndpoint(MCP_ENDPOINT).build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transportProvider; + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java index a819920c4..ccf3170c9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java @@ -4,9 +4,8 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.startup.Tomcat; @@ -25,32 +24,34 @@ class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; - private static final int PORT = 8181; + private static final int PORT = TestUtil.findAvailablePort(); private Tomcat tomcat; - private WebMvcSseServerTransport transport; + private McpServerTransportProvider transportProvider; @Configuration @EnableWebMvc static class TestConfig { @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + return WebMvcSseServerTransportProvider.builder() + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .build(); } @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); } } private AnnotationConfigWebApplicationContext appContext; - @Override - protected ServerMcpTransport createMcpTransport() { + private McpServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); @@ -69,11 +70,10 @@ protected ServerMcpTransport createMcpTransport() { appContext.refresh(); // Get the transport from Spring context - transport = appContext.getBean(WebMvcSseServerTransport.class); + transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); // Create DispatcherServlet with our Spring context DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); // Add servlet to Tomcat and get the wrapper var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); @@ -88,7 +88,12 @@ protected ServerMcpTransport createMcpTransport() { throw new RuntimeException("Failed to start Tomcat", e); } - return transport; + return transportProvider; + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); } @Override @@ -97,8 +102,8 @@ protected void onStart() { @Override protected void onClose() { - if (transport != null) { - transport.closeGracefully().block(); + if (transportProvider != null) { + transportProvider.closeGracefully().block(); } if (appContext != null) { appContext.close(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java new file mode 100644 index 000000000..d8d26af48 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -0,0 +1,110 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; + +class WebMvcSseCustomContextPathTests { + + private static final String CUSTOM_CONTEXT_PATH = "/app/1"; + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private TomcatTestUtil.TomcatServer tomcatServer; + + @BeforeEach + public void before() { + + tomcatServer = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_CONTEXT_PATH + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .build(); + + clientBuilder = McpClient.sync(clientTransport); + + mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void testCustomContextPath() { + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + assertThat(client.initialize()).isNotNull(); + } + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + + return WebMvcSseServerTransportProvider.builder() + .baseUrl(CUSTOM_CONTEXT_PATH) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .build(); + // return new WebMvcSseServerTransportProvider(new ObjectMapper(), + // CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT, + // WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 62f696375..045f9b3dd 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -3,138 +3,120 @@ */ package io.modelcontextprotocol.server; +import static org.assertj.core.api.Assertions.assertThat; + import java.time.Duration; -import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; +import java.util.stream.Stream; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; -import org.apache.catalina.startup.Tomcat; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import reactor.test.StepVerifier; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.web.client.RestClient; -import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerRequest; import org.springframework.web.servlet.function.ServerResponse; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.awaitility.Awaitility.await; +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import reactor.core.scheduler.Schedulers; -public class WebMvcSseIntegrationTests { +@Timeout(15) +class WebMvcSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { - private static final int PORT = 8183; + private static final int PORT = TestUtil.findAvailablePort(); private static final String MESSAGE_ENDPOINT = "/mcp/message"; - private WebMvcSseServerTransport mcpServerTransport; + private WebMvcSseServerTransportProvider mcpServerTransportProvider; + + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = r -> McpTransportContext + .create(Map.of("important", "value")); - McpClient.SyncSpec clientBuilder; + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + port).build()) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", McpClient + .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + port)).build()) + .requestTimeout(Duration.ofHours(10))); + } @Configuration @EnableWebMvc static class TestConfig { @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + return WebMvcSseServerTransportProvider.builder() + .messageEndpoint(MESSAGE_ENDPOINT) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .build(); } @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); } } - private Tomcat tomcat; - - private AnnotationConfigWebApplicationContext appContext; + private TomcatTestUtil.TomcatServer tomcatServer; @BeforeEach public void before() { - // Set up Tomcat first - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - // Set Tomcat base directory to java.io.tmpdir to avoid permission issues - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - // Use the same directory for document base - Context context = tomcat.addContext("", baseDir); - - // Create and configure Spring WebMvc context - appContext = new AnnotationConfigWebApplicationContext(); - appContext.register(TestConfig.class); - appContext.setServletContext(context.getServletContext()); - appContext.refresh(); - - // Get the transport from Spring context - mcpServerTransport = appContext.getBean(WebMvcSseServerTransport.class); - - // Create DispatcherServlet with our Spring context - DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); - - // Add servlet to Tomcat and get the wrapper - var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); - wrapper.setLoadOnStartup(1); - wrapper.setAsyncSupported(true); - context.addServletMappingDecoded("/*", "dispatcherServlet"); + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); try { - // Configure and start the connector with async support - var connector = tomcat.getConnector(); - connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests - tomcat.start(); - assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); } catch (Exception e) { throw new RuntimeException("Failed to start Tomcat", e); } - this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + prepareClients(PORT, MESSAGE_ENDPOINT); + + // Get the transport from Spring context + mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); + } @AfterEach public void after() { - if (mcpServerTransport != null) { - mcpServerTransport.closeGracefully().block(); + reactor.netty.http.HttpResources.disposeLoopsAndConnections(); + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); } - if (appContext != null) { - appContext.close(); + Schedulers.shutdownNow(); + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); } - if (tomcat != null) { + if (tomcatServer.tomcat() != null) { try { - tomcat.stop(); - tomcat.destroy(); + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); } catch (LifecycleException e) { throw new RuntimeException("Failed to stop Tomcat", e); @@ -142,366 +124,14 @@ public void after() { } } - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @Test - void testCreateMessageWithoutInitialization() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized. Call the initialize method first!"); - }); - } - - @Test - void testCreateMessageWithoutSamplingCapabilities() { - - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - }); - } - - @Test - void testCreateMessageSuccess() throws InterruptedException { - - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpServerTransportProvider); } - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @Test - void testRootsSuccess() { - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpServer.listRoots().roots()).containsAll(roots); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testRootsWithoutCapability() { - var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { - }).build(); - - // Create client without roots capability - // No roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Attempt to list roots should fail - assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) - .hasMessage("Roots not supported"); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testRootsWithEmptyRootsList() { - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testRootsWithMultipleConsumers() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) - .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testRootsServerCloseWithActiveSubscription() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Close server while subscription is active - mcpServer.close(); - - // Verify client can handle server closure gracefully - mcpClient.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testToolCallSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - var mcpClient = clientBuilder.build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testToolListChangeHandlingSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testInitialize() { - - var mcpServer = McpServer.sync(mcpServerTransport).build(); - - var mcpClient = clientBuilder.build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.close(); - mcpServer.close(); + @Override + protected SingleSessionSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpServerTransportProvider); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java index 249b4deaf..66d6d3ae9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java @@ -4,9 +4,7 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.startup.Tomcat; @@ -25,24 +23,24 @@ class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; - private static final int PORT = 8181; + private static final int PORT = TestUtil.findAvailablePort(); private Tomcat tomcat; - private WebMvcSseServerTransport transport; + private WebMvcSseServerTransportProvider transportProvider; @Configuration @EnableWebMvc static class TestConfig { @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + return WebMvcSseServerTransportProvider.builder().messageEndpoint(MESSAGE_ENDPOINT).build(); } @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); } } @@ -50,7 +48,11 @@ public RouterFunction routerFunction(WebMvcSseServerTransport tr private AnnotationConfigWebApplicationContext appContext; @Override - protected ServerMcpTransport createMcpTransport() { + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + private WebMvcSseServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); @@ -69,11 +71,10 @@ protected ServerMcpTransport createMcpTransport() { appContext.refresh(); // Get the transport from Spring context - transport = appContext.getBean(WebMvcSseServerTransport.class); + transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); // Create DispatcherServlet with our Spring context DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); // Add servlet to Tomcat and get the wrapper var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); @@ -88,7 +89,7 @@ protected ServerMcpTransport createMcpTransport() { throw new RuntimeException("Failed to start Tomcat", e); } - return transport; + return transportProvider; } @Override @@ -97,8 +98,8 @@ protected void onStart() { @Override protected void onClose() { - if (transport != null) { - transport.closeGracefully().block(); + if (transportProvider != null) { + transportProvider.closeGracefully().block(); } if (appContext != null) { appContext.close(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java new file mode 100644 index 000000000..8c7b0a85e --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java @@ -0,0 +1,134 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; +import java.util.stream.Stream; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import io.modelcontextprotocol.AbstractStatelessIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; +import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; +import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport; +import reactor.core.scheduler.Schedulers; + +@Timeout(15) +class WebMvcStatelessIntegrationTests extends AbstractStatelessIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcStatelessServerTransport mcpServerTransport; + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStatelessServerTransport webMvcStatelessServerTransport() { + + return WebMvcStatelessServerTransport.builder().messageEndpoint(MESSAGE_ENDPOINT).build(); + + } + + @Bean + public RouterFunction routerFunction(WebMvcStatelessServerTransport statelessServerTransport) { + return statelessServerTransport.getRouterFunction(); + } + + } + + private TomcatTestUtil.TomcatServer tomcatServer; + + @Override + protected StatelessAsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransport); + } + + @Override + protected StatelessSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransport); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + port)) + .endpoint(mcpEndpoint) + .build()) + .requestTimeout(Duration.ofHours(10))); + } + + @BeforeEach + public void before() { + + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + prepareClients(PORT, MESSAGE_ENDPOINT); + + // Get the transport from Spring context + this.mcpServerTransport = tomcatServer.appContext().getBean(WebMvcStatelessServerTransport.class); + + } + + @AfterEach + public void after() { + reactor.netty.http.HttpResources.disposeLoopsAndConnections(); + if (this.mcpServerTransport != null) { + this.mcpServerTransport.closeGracefully().block(); + } + Schedulers.shutdownNow(); + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java new file mode 100644 index 000000000..cb7b4a2a0 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java @@ -0,0 +1,153 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import reactor.core.scheduler.Schedulers; + +@Timeout(15) +class WebMvcStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcStreamableServerTransportProvider mcpServerTransportProvider; + + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = r -> McpTransportContext + .create(Map.of("important", "value")); + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransportProvider() { + return WebMvcStreamableServerTransportProvider.builder() + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .mcpEndpoint(MESSAGE_ENDPOINT) + .build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private TomcatTestUtil.TomcatServer tomcatServer; + + @BeforeEach + public void before() { + + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient.sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(MESSAGE_ENDPOINT) + .build())); + + // Get the transport from Spring context + this.mcpServerTransportProvider = tomcatServer.appContext() + .getBean(WebMvcStreamableServerTransportProvider.class); + + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + reactor.netty.http.HttpResources.disposeLoopsAndConnections(); + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + Schedulers.shutdownNow(); + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + port)) + .endpoint(mcpEndpoint) + .build()) + .requestTimeout(Duration.ofHours(10))); + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java new file mode 100644 index 000000000..1074e8a35 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java @@ -0,0 +1,119 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for WebMvcSseServerTransportProvider + * + * @author lance + */ +class WebMvcSseServerTransportProviderTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String CUSTOM_CONTEXT_PATH = ""; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private TomcatTestUtil.TomcatServer tomcatServer; + + @BeforeEach + public void before() { + tomcatServer = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .build(); + + clientBuilder = McpClient.sync(transport); + mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); + } + + @Test + void validBaseUrl() { + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { + assertThat(client.initialize()).isNotNull(); + } + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + + return WebMvcSseServerTransportProvider.builder() + .baseUrl("http://localhost:" + PORT + "/") + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .jsonMapper(McpJsonMapper.getDefault()) + .contextExtractor(req -> McpTransportContext.EMPTY) + .build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml index bc1140bb5..d4ccbc173 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml @@ -9,16 +9,16 @@ - + - + - + - + diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index 717f03198..7fc22e5d2 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.8.0-SNAPSHOT + 0.18.0-SNAPSHOT mcp-test jar @@ -24,7 +24,7 @@ io.modelcontextprotocol.sdk mcp - 0.8.0-SNAPSHOT + 0.18.0-SNAPSHOT @@ -54,6 +54,11 @@ junit-jupiter-api ${junit.version} + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + org.mockito mockito-core @@ -68,6 +73,11 @@ junit-jupiter ${testcontainers.version} + + org.testcontainers + toxiproxy + ${toxiproxy.version} + org.awaitility @@ -80,6 +90,14 @@ logback-classic ${logback.version} + + + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} + + + diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java new file mode 100644 index 000000000..270bc4308 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -0,0 +1,1760 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.util.Utils; +import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +public abstract class AbstractMcpClientServerIntegrationTests { + + protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + abstract protected void prepareClients(int port, String mcpEndpoint); + + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); + + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + finally { + server.closeGracefully().block(); + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + return exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)) + .then(Mono.just(mock(CallToolResult.class))); + }) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + finally { + server.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { + + // Client + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + // Server + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(4)) + .tools(tool) + .build(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest).thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("1000ms"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> exchange.createElicitation(mock(ElicitRequest.class)) + .then(Mono.just(mock(CallToolResult.class)))) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without elicitation capabilities + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + finally { + server.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + assertWith(resultRef.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + var latch = new CountDownLatch(1); + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + try { + if (!latch.await(2, TimeUnit.SECONDS)) { + throw new RuntimeException("Timeout waiting for elicitation processing"); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = CallToolResult.builder().addContent(new TextContent("CALL RESPONSE")).build(); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) // 1 second. + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + ElicitResult elicitResult = resultRef.get(); + assertThat(elicitResult).isNull(); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsWithoutCapability(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + try ( + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsNotificationWithEmptyRootsList(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsWithMultipleHandlers(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var responseBodyIsNullOrBlank = new AtomicBoolean(false); + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=importantValue")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); + } + catch (Exception e) { + e.printStackTrace(); + } + + return callResponse; + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(responseBodyIsNullOrBlank.get()).isFalse(); + assertThat(response).isNotNull().isEqualTo(callResponse); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpSyncServer mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool1") + .description("tool1 description") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + // We trigger a timeout on blocking read, raising an exception + Mono.never().block(Duration.ofSeconds(1)); + return null; + }) + .build()) + .build(); + + try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // We expect the tool call to fail immediately with the exception raised by + // the offending tool instead of getting back a timeout. + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) + .withMessageContaining("Timeout on blocking read"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolCallSuccessWithTranportContextExtraction(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var transportContextIsNull = new AtomicBoolean(false); + var transportContextIsEmpty = new AtomicBoolean(false); + var responseBodyIsNullOrBlank = new AtomicBoolean(false); + + var expectedCallResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=value")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + McpTransportContext transportContext = exchange.transportContext(); + transportContextIsNull.set(transportContext == null); + transportContextIsEmpty.set(transportContext.equals(McpTransportContext.EMPTY)); + String ctxValue = (String) transportContext.get("important"); + + try { + String responseBody = "TOOL RESPONSE"; + responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); + } + catch (Exception e) { + e.printStackTrace(); + } + + return McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(transportContextIsNull.get()).isFalse(); + assertThat(transportContextIsEmpty.get()).isFalse(); + assertThat(responseBodyIsNullOrBlank.get()).isFalse(); + assertThat(response).isNotNull().isEqualTo(expectedCallResponse); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + return callResponse; + }) + .build(); + + AtomicReference> toolsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + toolsRef.set(toolsUpdate); + } + catch (Exception e) { + e.printStackTrace(); + } + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(toolsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool2") + .description("tool2 description") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).containsAll(List.of(tool2.tool())); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = prepareSyncServerBuilder().build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("logging-test") + .description("Test logging notifications") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications with different levels + + //@formatter:off + return exchange // This should be filtered out (DEBUG < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Logging test completed"))) + .isError(false) + .build()); + //@formatter:on + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Progress Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testProgressNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress + // token + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("progress-test") + .description("Test progress notifications") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications + var progressToken = (String) request.meta().get("progressToken"); + + return exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) + .then(// Send a progress notification with another progress value + // should + exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", + 0.0, 1.0, "Another processing started"))) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Progress test completed"))) + .isError(false) + .build()); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with progress notification handler + var mcpClient = clientBuilder.progressConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that sends progress notifications + McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() + .name("progress-test") + .meta(Map.of("progressToken", "test-progress-token")) + .build(); + CallToolResult result = mcpClient.callTool(callToolRequest); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); + + // Second notification should be 0.5/1.0 progress + assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); + + // Third notification should be another progress token with 0.0/1.0 progress + assertThat(notificationMap.get("Another processing started").progressToken()) + .isEqualTo("another-progress-token"); + assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Another processing started").message()) + .isEqualTo("Another processing started"); + + // Fourth notification should be 1.0/1.0 progress + assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @MethodSource("clientsForTesting") + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (mcpSyncServerExchange, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new McpSchema.PromptReference(PromptReference.TYPE, "code_review", "Code review"), + completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testPingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that uses ping functionality + AtomicReference executionOrder = new AtomicReference<>(""); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("ping-async-test") + .description("Test ping async behavior") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + + executionOrder.set(executionOrder.get() + "1"); + + // Test async ping behavior + return exchange.ping().doOnNext(result -> { + + assertThat(result).isNotNull(); + // Ping should return an empty object or map + assertThat(result).isInstanceOf(Map.class); + + executionOrder.set(executionOrder.get() + "2"); + assertThat(result).isNotNull(); + }).then(Mono.fromCallable(() -> { + executionOrder.set(executionOrder.get() + "3"); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Async ping test completed"))) + .isError(false) + .build(); + })); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that tests ping async behavior + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); + + // Verify execution order + assertThat(executionOrder.get()).isEqualTo("123"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + // In WebMVC, structured content is returned properly + if (response.structuredContent() != null) { + assertThat((Map) response.structuredContent()).containsEntry("result", 5.0) + .containsEntry("operation", "2 + 3") + .containsEntry("timestamp", "2024-01-01T10:00:00Z"); + } + else { + // Fallback to checking content if structured content is not available + assertThat(response.content()).isNotEmpty(); + } + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that returns an error result + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationFailure(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputMissingStructuredContent(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputRuntimeToolAddition(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification toolSpec = McpServerFeatures.SyncToolSpecification.builder() + .tool(dynamicTool) + .callHandler((exchange, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }) + .build(); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java new file mode 100644 index 000000000..240732ebe --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java @@ -0,0 +1,659 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; +import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpStatelessSyncServer; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.awaitility.Awaitility.await; + +public abstract class AbstractStatelessIntegrationTests { + + protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + abstract protected void prepareClients(int port, String mcpEndpoint); + + abstract protected StatelessAsyncSpecification prepareAsyncServerBuilder(); + + abstract protected StatelessSyncSpecification prepareSyncServerBuilder(); + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + finally { + server.closeGracefully().block(); + } + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpStatelessServerFeatures.SyncToolSpecification tool1 = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((ctx, request) -> { + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + return callResponse; + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull().isEqualTo(callResponse); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpStatelessSyncServer mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool1") + .description("tool1 description") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((context, request) -> { + // We trigger a timeout on blocking read, raising an exception + Mono.never().block(Duration.ofSeconds(1)); + return null; + }) + .build()) + .build(); + + try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // We expect the tool call to fail immediately with the exception raised by + // the offending tool + // instead of getting back a timeout. + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) + .withMessageContaining("Timeout on blocking read"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpStatelessServerFeatures.SyncToolSpecification tool1 = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((ctx, request) -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + return callResponse; + }) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + rootsRef.set(toolsUpdate); + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + // Remove a tool + mcpServer.removeTool("tool1"); + + // Add a new tool + McpStatelessServerFeatures.SyncToolSpecification tool2 = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(Tool.builder() + .name("tool2") + .description("tool2 description") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = prepareSyncServerBuilder().build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + // In WebMVC, structured content is returned properly + if (response.structuredContent() != null) { + assertThat((Map) response.structuredContent()).containsEntry("result", 5.0) + .containsEntry("operation", "2 + 3") + .containsEntry("timestamp", "2024-01-01T10:00:00Z"); + } + else { + // Fallback to checking content if structured content is not available + assertThat(response.content()).isNotEmpty(); + } + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that throws an exception to simulate an error + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputValidationFailure(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputMissingStructuredContent(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + var tool = McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputRuntimeToolAddition(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + var toolSpec = McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(dynamicTool) + .callHandler((exchange, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }) + .build(); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java index d4e48ea7d..cd8458311 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java @@ -9,21 +9,24 @@ import java.util.function.BiConsumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.McpServerTransport; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; /** - * A mock implementation of the {@link ClientMcpTransport} and {@link ServerMcpTransport} + * A mock implementation of the {@link McpClientTransport} and {@link McpServerTransport} * interfaces. + * + * @deprecated not used. to be removed in the future. */ -public class MockMcpTransport implements ClientMcpTransport, ServerMcpTransport { +@Deprecated +public class MockMcpTransport implements McpClientTransport, McpServerTransport { private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); @@ -90,8 +93,8 @@ public Mono closeGracefully() { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return McpJsonMapper.getDefault().convertValue(data, typeRef); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..338eaf931 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -0,0 +1,230 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpTransportSessionClosedException; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Resiliency test suite for the {@link McpAsyncClient} that can be used with different + * {@link McpTransport} implementations that support Streamable HTTP. + * + * The purpose of these tests is to allow validating the transport layer resiliency + * instead of the functionality offered by the logical layer of MCP concepts such as + * tools, resources, prompts, etc. + * + * @author Dariusz Jędrzejczyk + */ +public abstract class AbstractMcpAsyncClientResiliencyTests { + + private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + static void disconnect() { + long start = System.nanoTime(); + try { + // disconnect + // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", + // ToxicDirection.DOWNSTREAM, 0); + // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", + // ToxicDirection.UPSTREAM, 0); + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + static void restartMcpServer() { + container.stop(); + container.start(); + } + + abstract McpClientTransport createMcpTransport(); + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + // Do not advertise roots. Otherwise, the server will list roots during + // initialization. The client responds asynchronously, and there might be a + // rest condition in tests where we disconnect right after initialization. + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(McpSchema.ClientCapabilities.builder().build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionInvalidation() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + restartMcpServer(); + + // The first try will face the session mismatch exception and the second one + // will go through the re-initialization process. + StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + AtomicReference> tools = new AtomicReference<>(); + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + StepVerifier.create(mcpAsyncClient.listTools()) + .consumeNextWith(list -> tools.set(list.tools())) + .verifyComplete(); + + disconnect(); + + String name = tools.get().get(0).name(); + // Assuming this is the echo tool + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(name, Map.of("message", "hello")); + StepVerifier.create(mcpAsyncClient.callTool(request)).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.callTool(request)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionClose() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + // In case of Streamable HTTP this call should issue a HTTP DELETE request + // invalidating the session + StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); + // The next tries to use the closed session and fails + StepVerifier.create(mcpAsyncClient.ping()) + .expectErrorMatches(err -> err.getCause() instanceof McpTransportSessionClosedException) + .verify(); + }); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index cdcba4d1c..bee8f4f16 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -4,37 +4,55 @@ package io.modelcontextprotocol.client; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + import java.time.Duration; +import java.util.ArrayList; +import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ResourceContents; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest; import io.modelcontextprotocol.spec.McpTransport; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - /** * Test suite for the {@link McpAsyncClient} that can be used with different * {@link McpTransport} implementations. @@ -44,289 +62,571 @@ */ public abstract class AbstractMcpAsyncClientTests { - private McpAsyncClient mcpAsyncClient; - - protected ClientMcpTransport mcpTransport; - private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected ClientMcpTransport createMcpTransport(); + abstract protected McpClientTransport createMcpTransport(); - protected void onStart() { + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); } - protected void onClose() { + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(20); } - protected Duration getTimeoutDuration() { - return Duration.ofSeconds(2); + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpAsyncClient = McpClient.async(mcpTransport) - .requestTimeout(getTimeoutDuration()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .sampling(req -> Mono.just(new CreateMessageResult(McpSchema.Role.USER, + new McpSchema.TextContent("Oh, hi!"), "modelId", CreateMessageResult.StopReason.END_TURN))) + .capabilities(ClientCapabilities.builder().roots(true).sampling().build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); } - @AfterEach - void tearDown() { - if (mcpAsyncClient != null) { - assertThatCode(() -> mcpAsyncClient.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); } - onClose(); + } + + void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, + String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(operation.apply(mcpAsyncClient)).verifyComplete(); + }); + } + + void verifyCallSucceedsWithImplicitInitialization(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(operation.apply(mcpAsyncClient)).expectNextCount(1).verifyComplete(); + }); } @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listTools(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(McpSchema.FIRST_PAGE), "listing tools"); } @Test void testListTools() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listTools(null)).consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); - - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(McpSchema.FIRST_PAGE))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); + + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllTools() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools())) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); + + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllToolsReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools())) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy(() -> result.tools() + .add(Tool.builder() + .name("test") + .title("test") + .inputSchema(JSON_MAPPER, "{\"type\":\"object\"}") + .build())) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); } @Test void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.ping().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test void testPing() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - assertThatCode(() -> mcpAsyncClient.ping().block()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) + .expectNextCount(1) + .verifyComplete(); + }); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpAsyncClient.callTool(callToolRequest).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test void testCallTool() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)).consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }) + .verifyComplete(); + }); } @Test void testCallToolWithInvalidTool() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpAsyncClient.callTool(invalidRequest).block()).isInstanceOf(Exception.class); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", + Map.of("message", ECHO_TEST_MESSAGE)); + + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .consumeErrorWith( + e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) + .verify(); + }); + } + + @ParameterizedTest + @ValueSource(strings = { "success", "error", "debug" }) + void testCallToolWithMessageAnnotations(String messageType) { + McpClientTransport transport = createMcpTransport(); + + withClient(transport, mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize() + .then(mcpAsyncClient.callTool(new McpSchema.CallToolRequest("annotatedMessage", + Map.of("messageType", messageType, "includeImage", true))))) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isNotEqualTo(true); + assertThat(result.content()).isNotEmpty(); + assertThat(result.content()).allSatisfy(content -> { + switch (content.type()) { + case "text": + McpSchema.TextContent textContent = assertInstanceOf(McpSchema.TextContent.class, + content); + assertThat(textContent.text()).isNotEmpty(); + assertThat(textContent.annotations()).isNotNull(); + + switch (messageType) { + case "error": + assertThat(textContent.annotations().priority()).isEqualTo(1.0); + assertThat(textContent.annotations().audience()) + .containsOnly(McpSchema.Role.USER, McpSchema.Role.ASSISTANT); + break; + case "success": + assertThat(textContent.annotations().priority()).isEqualTo(0.7); + assertThat(textContent.annotations().audience()) + .containsExactly(McpSchema.Role.USER); + break; + case "debug": + assertThat(textContent.annotations().priority()).isEqualTo(0.3); + assertThat(textContent.annotations().audience()) + .containsExactly(McpSchema.Role.ASSISTANT); + break; + default: + throw new IllegalStateException("Unexpected value: " + content.type()); + } + break; + case "image": + McpSchema.ImageContent imageContent = assertInstanceOf(McpSchema.ImageContent.class, + content); + assertThat(imageContent.data()).isNotEmpty(); + assertThat(imageContent.annotations()).isNotNull(); + assertThat(imageContent.annotations().priority()).isEqualTo(0.5); + assertThat(imageContent.annotations().audience()).containsExactly(McpSchema.Role.USER); + break; + default: + fail("Unexpected content type: " + content.type()); + } + }); + }) + .verifyComplete(); + }); } @Test void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listResources(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(McpSchema.FIRST_PAGE), + "listing resources"); } @Test void testListResources() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listResources(null)).consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(McpSchema.FIRST_PAGE))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllResources() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources())) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllResourcesReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources())) + .consumeNextWith(result -> { + assertThat(result.resources()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy( + () -> result.resources().add(Resource.builder().uri("test://uri").name("test").build())) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); } @Test void testMcpAsyncClientState() { - assertThat(mcpAsyncClient).isNotNull(); + withClient(createMcpTransport(), mcpAsyncClient -> { + assertThat(mcpAsyncClient).isNotNull(); + }); } @Test void testListPromptsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listPrompts(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listPrompts(McpSchema.FIRST_PAGE), + "listing " + "prompts"); } @Test void testListPrompts() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listPrompts(null)).consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(McpSchema.FIRST_PAGE))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllPrompts() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts())) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllPromptsReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts())) + .consumeNextWith(result -> { + assertThat(result.prompts()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy(() -> result.prompts().add(new Prompt("test", "test", "test", null))) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); } @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - - assertThatThrownBy(() -> mcpAsyncClient.getPrompt(request).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before getting prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.getPrompt(request), "getting " + "prompts"); } @Test void testGetPrompt() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of()))) - .consumeNextWith(prompt -> { - assertThat(prompt).isNotNull().satisfies(result -> { - assertThat(result.messages()).isNotEmpty(); - assertThat(result.messages()).hasSize(1); - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize() + .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) + .consumeNextWith(prompt -> { + assertThat(prompt).isNotNull().satisfies(result -> { + assertThat(result.messages()).isNotEmpty(); + assertThat(result.messages()).hasSize(1); + }); + }) + .verifyComplete(); + }); } @Test void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.rootsListChangedNotification().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); } @Test void testRootsListChanged() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - assertThatCode(() -> mcpAsyncClient.rootsListChangedNotification().block()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - assertThatCode(() -> client.initialize().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - - assertThatCode(() -> client.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + client -> { + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpAsyncClient.addRoot(newRoot).block()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + }); } @Test void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpAsyncClient.addRoot(null).block()).hasMessageContaining("Root must not be null"); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Root must not be null")) + .verify(); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpAsyncClient.addRoot(root).block(); - mcpAsyncClient.removeRoot(root.uri()).block(); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); + + StepVerifier.create(mcpAsyncClient.removeRoot(root.uri())).verifyComplete(); + }); } @Test void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpAsyncClient.removeRoot("nonexistent-uri").block()) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalStateException.class) + .hasMessage("Root with uri 'nonexistent-uri' not found")) + .verify(); + }); } @Test - @Disabled void testReadResource() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - }).verifyComplete(); - } - }).verifyComplete(); + AtomicInteger resourceCount = new AtomicInteger(); + withClient(createMcpTransport(), client -> { + Flux resources = client.initialize() + .then(client.listResources(null)) + .flatMapMany(r -> { + List l = r.resources(); + resourceCount.set(l.size()); + return Flux.fromIterable(l); + }) + .flatMap(r -> client.readResource(r)); + + StepVerifier.create(resources) + .recordWith(ArrayList::new) + .thenConsumeWhile(res -> true) + .consumeRecordedWith(readResourceResults -> { + assertThat(readResourceResults.size()).isEqualTo(resourceCount.get()); + for (ReadResourceResult result : readResourceResults) { + + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull().isNotEmpty(); + + // Validate each content item + for (ResourceContents content : result.contents()) { + assertThat(content).isNotNull(); + assertThat(content.uri()).isNotNull().isNotEmpty(); + assertThat(content.mimeType()).isNotNull().isNotEmpty(); + + // Validate content based on its type with more comprehensive + // checks + switch (content.mimeType()) { + case "text/plain" -> { + TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, + content); + assertThat(textContent.text()).isNotNull().isNotEmpty(); + assertThat(textContent.uri()).isNotEmpty(); + } + case "application/octet-stream" -> { + BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, + content); + assertThat(blobContent.blob()).isNotNull().isNotEmpty(); + assertThat(blobContent.uri()).isNotNull().isNotEmpty(); + // Validate base64 encoding format + assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$"); + } + default -> { + + // Still validate basic properties + if (content instanceof TextResourceContents textContent) { + assertThat(textContent.text()).isNotNull(); + } + else if (content instanceof BlobResourceContents blobContent) { + assertThat(blobContent.blob()).isNotNull(); + } + } + } + } + } + }) + .verifyComplete(); + }); } @Test void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listResourceTemplates().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(McpSchema.FIRST_PAGE), + "listing resource templates"); } @Test void testListResourceTemplates() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates(McpSchema.FIRST_PAGE))) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); + }); + } - StepVerifier.create(mcpAsyncClient.listResourceTemplates()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }).verifyComplete(); + @Test + void testListAllResourceTemplates() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); + }); + } + + @Test + void testListAllResourceTemplatesReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result.resourceTemplates()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy(() -> result.resourceTemplates() + .add(new McpSchema.ResourceTemplate("test://template", "test", "test", null, null, null))) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); } // @Test void testResourceSubscription() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - - // Test subscribe - StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .verifyComplete(); - - // Test unsubscribe - StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .verifyComplete(); - } - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + + // Test subscribe + StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .verifyComplete(); + + // Test unsubscribe + StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .verifyComplete(); + } + }).verifyComplete(); + }); } @Test @@ -335,105 +635,194 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) - .build(); - - assertThatCode(() -> { - client.initialize().block(); - client.closeGracefully().block(); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), + builder -> builder + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), + mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()) + .expectNextMatches(Objects::nonNull) + .verifyComplete(); + }); } @Test void testInitializeWithSamplingCapability() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder().sampling().build(); - - var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) - .capabilities(capabilities) - .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) + ClientCapabilities capabilities = ClientCapabilities.builder().sampling().build(); + CreateMessageResult createMessageResult = CreateMessageResult.builder() + .message("test") + .model("test-model") .build(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); + } - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + @Test + void testInitializeWithElicitationCapability() { + ClientCapabilities capabilities = ClientCapabilities.builder().elicitation().build(); + ElicitResult elicitResult = ElicitResult.builder() + .message(ElicitResult.Action.ACCEPT) + .content(Map.of("foo", "bar")) + .build(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).elicitation(request -> Mono.just(elicitResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithAllCapabilities() { - var transport = createMcpTransport(); - var capabilities = ClientCapabilities.builder() - .experimental(Map.of("feature", "test")) + .experimental(Map.of("feature", Map.of("featureFlag", true))) .roots(true) .sampling() .build(); Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) - .capabilities(capabilities) - .sampling(samplingHandler) - .build(); - assertThatCode(() -> { - var result = client.initialize().block(Duration.ofSeconds(10)); - assertThat(result).isNotNull(); - assertThat(result.capabilities()).isNotNull(); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); - } + Function> elicitationHandler = request -> Mono + .just(ElicitResult.builder().message(ElicitResult.Action.ACCEPT).content(Map.of("foo", "bar")).build()); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(samplingHandler).elicitation(elicitationHandler), + client -> + + StepVerifier.create(client.initialize()).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.capabilities()).isNotNull(); + }).verifyComplete()); + } // --------------------------------------- // Logging Tests // --------------------------------------- @Test void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG).block()) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test void testLoggingLevels() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(level)).verifyComplete(); - } + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize() + .thenMany(Flux.fromArray(McpSchema.LoggingLevel.values()).flatMap(mcpAsyncClient::setLoggingLevel))) + .verifyComplete(); + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) - .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) - .build(); + withClient(createMcpTransport(), + builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + StepVerifier.create(client.closeGracefully()).verifyComplete(); + + }); - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); } @Test void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(null).block()) - .hasMessageContaining("Logging level must not be null"); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); + }); + } + + @Test + void testSampling() { + McpClientTransport transport = createMcpTransport(); + + final String message = "Hello, world!"; + final String response = "Goodbye, world!"; + final int maxTokens = 100; + + AtomicReference receivedPrompt = new AtomicReference<>(); + AtomicReference receivedMessage = new AtomicReference<>(); + AtomicInteger receivedMaxTokens = new AtomicInteger(); + + withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build()) + .sampling(request -> { + McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class, + request.messages().get(0).content()); + receivedPrompt.set(request.systemPrompt()); + receivedMessage.set(messageText.text()); + receivedMaxTokens.set(request.maxTokens()); + + return Mono + .just(new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response), + "modelId", McpSchema.CreateMessageResult.StopReason.END_TURN)); + }), client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool( + new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens)))) + .consumeNextWith(result -> { + // Verify tool response to ensure our sampling response was passed + // through + assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class); + assertThat(result.content()).allSatisfy(content -> { + if (!(content instanceof McpSchema.TextContent text)) + return; + + assertThat(text.text()).contains(response); + }); + + // Verify sampling request parameters received in our callback + assertThat(receivedPrompt.get()).isNotEmpty(); + assertThat(receivedMessage.get()).endsWith(message); // Prefixed + assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens); + }) + .verifyComplete(); + }); + } + + // --------------------------------------- + // Progress Notification Tests + // --------------------------------------- + + @Test + void testProgressConsumer() { + Sinks.Many sink = Sinks.many().unicast().onBackpressureBuffer(); + List receivedNotifications = new CopyOnWriteArrayList<>(); + + withClient(createMcpTransport(), builder -> builder.progressConsumer(notification -> { + receivedNotifications.add(notification); + sink.tryEmitNext(notification); + return Mono.empty(); + }), client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + // Call a tool that sends progress notifications + CallToolRequest request = CallToolRequest.builder() + .name("longRunningOperation") + .arguments(Map.of("duration", 1, "steps", 2)) + .progressToken("test-token") + .build(); + + StepVerifier.create(client.callTool(request)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + }).verifyComplete(); + + // Use StepVerifier to verify the progress notifications via the sink + StepVerifier.create(sink.asFlux()).expectNextCount(2).thenCancel().verify(Duration.ofSeconds(3)); + + assertThat(receivedNotifications).hasSize(2); + assertThat(receivedNotifications.get(0).progressToken()).isEqualTo("test-token"); + }); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index aeed06cbf..26d60568a 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -4,13 +4,33 @@ package io.modelcontextprotocol.client; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + import java.time.Duration; +import java.util.List; import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; -import io.modelcontextprotocol.spec.McpError; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; @@ -19,18 +39,16 @@ import io.modelcontextprotocol.spec.McpSchema.ListToolsResult; import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ResourceContents; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest; import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; /** * Unit tests for MCP Client Session functionality. @@ -40,41 +58,69 @@ */ public abstract class AbstractMcpSyncClientTests { - private McpSyncClient mcpSyncClient; + private static final Logger logger = LoggerFactory.getLogger(AbstractMcpSyncClientTests.class); private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - protected ClientMcpTransport mcpTransport; - - abstract protected ClientMcpTransport createMcpTransport(); + abstract protected McpClientTransport createMcpTransport(); - abstract protected void onStart(); - - abstract protected void onClose(); + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } - protected Duration getTimeoutDuration() { + protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpSyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpSyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpSyncClient = McpClient.sync(mcpTransport) - .requestTimeout(getTimeoutDuration()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + McpClient.SyncSpec builder = McpClient.sync(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); } - @AfterEach - void tearDown() { - if (mcpSyncClient != null) { - assertThatCode(() -> mcpSyncClient.close()).doesNotThrowAnyException(); + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + assertThat(client.closeGracefully()).isTrue(); } - onClose(); + } + + static final Object DUMMY_RETURN_VALUE = new Object(); + + void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { + verifyCallSucceedsWithImplicitInitialization(client -> { + operation.accept(client); + return DUMMY_RETURN_VALUE; + }, action); + } + + void verifyCallSucceedsWithImplicitInitialization(Function blockingOperation, String action) { + withClient(createMcpTransport(), mcpSyncClient -> { + StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) + // Offload the blocking call to the real scheduler + .subscribeOn(Schedulers.boundedElastic())).expectNextCount(1).verifyComplete(); + }); } @Test @@ -82,227 +128,406 @@ void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listTools(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(McpSchema.FIRST_PAGE), "listing tools"); } @Test void testListTools() { - mcpSyncClient.initialize(); - ListToolsResult tools = mcpSyncClient.listTools(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListToolsResult tools = mcpSyncClient.listTools(McpSchema.FIRST_PAGE); + + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - assertThat(tools).isNotNull().satisfies(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }); + }); + } - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); + @Test + void testListAllTools() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListToolsResult tools = mcpSyncClient.listTools(); + + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); + + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }); }); } @Test void testCallToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4)))) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallSucceedsWithImplicitInitialization( + client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), "calling tools"); } @Test void testCallTools() { - mcpSyncClient.initialize(); - CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); - assertThat(toolResult).isNotNull().satisfies(result -> { + assertThat(toolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).hasSize(1); + assertThat(result.content()).hasSize(1); - TextContent content = (TextContent) result.content().get(0); + TextContent content = (TextContent) result.content().get(0); - assertThat(content).isNotNull(); - assertThat(content.text()).isNotNull(); - assertThat(content.text()).contains("7"); + assertThat(content).isNotNull(); + assertThat(content.text()).isNotNull(); + assertThat(content.text()).contains("7"); + }); }); } @Test void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.ping()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test void testPing() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + }); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpSyncClient.callTool(callToolRequest)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test void testCallTool() { - mcpSyncClient.initialize(); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); + CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); - assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + }); + } + + @ParameterizedTest + @ValueSource(strings = { "success", "error", "debug" }) + void testCallToolWithMessageAnnotations(String messageType) { + McpClientTransport transport = createMcpTransport(); + + withClient(transport, client -> { + client.initialize(); + + McpSchema.CallToolResult result = client.callTool(new McpSchema.CallToolRequest("annotatedMessage", + Map.of("messageType", messageType, "includeImage", true))); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isNotEqualTo(true); + assertThat(result.content()).isNotEmpty(); + assertThat(result.content()).allSatisfy(content -> { + switch (content.type()) { + case "text": + McpSchema.TextContent textContent = assertInstanceOf(McpSchema.TextContent.class, content); + assertThat(textContent.text()).isNotEmpty(); + assertThat(textContent.annotations()).isNotNull(); + + switch (messageType) { + case "error": + assertThat(textContent.annotations().priority()).isEqualTo(1.0); + assertThat(textContent.annotations().audience()).containsOnly(McpSchema.Role.USER, + McpSchema.Role.ASSISTANT); + break; + case "success": + assertThat(textContent.annotations().priority()).isEqualTo(0.7); + assertThat(textContent.annotations().audience()).containsExactly(McpSchema.Role.USER); + break; + case "debug": + assertThat(textContent.annotations().priority()).isEqualTo(0.3); + assertThat(textContent.annotations().audience()) + .containsExactly(McpSchema.Role.ASSISTANT); + break; + default: + throw new IllegalStateException("Unexpected value: " + content.type()); + } + break; + case "image": + McpSchema.ImageContent imageContent = assertInstanceOf(McpSchema.ImageContent.class, content); + assertThat(imageContent.data()).isNotEmpty(); + assertThat(imageContent.annotations()).isNotNull(); + assertThat(imageContent.annotations().priority()).isEqualTo(0.5); + assertThat(imageContent.annotations().audience()).containsExactly(McpSchema.Role.USER); + break; + default: + fail("Unexpected content type: " + content.type()); + } + }); + }); } @Test void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.rootsListChangedNotification()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); } @Test void testRootsListChanged() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + }); } @Test void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResources(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(McpSchema.FIRST_PAGE), + "listing resources"); } @Test void testListResources() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); - - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(McpSchema.FIRST_PAGE); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }); + } - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } + @Test + void testListAllResources() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); }); } @Test void testClientSessionState() { - assertThat(mcpSyncClient).isNotNull(); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThat(mcpSyncClient).isNotNull(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) - .roots(new Root("file:///test/path", "test-root")) - .build(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + mcpSyncClient -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + mcpSyncClient.initialize(); + mcpSyncClient.close(); + }).doesNotThrowAnyException(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + }); } @Test void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpSyncClient.addRoot(root); - mcpSyncClient.removeRoot(root.uri()); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + assertThatCode(() -> { + mcpSyncClient.addRoot(root); + mcpSyncClient.removeRoot(root.uri()); + }).doesNotThrowAnyException(); + }); } @Test void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) + .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + }); } @Test void testReadResourceWithoutInitialization() { - assertThatThrownBy(() -> { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - mcpSyncClient.readResource(resource); - }).isInstanceOf(McpError.class).hasMessage("Client must be initialized before reading resources"); + AtomicReference> resources = new AtomicReference<>(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + resources.set(mcpSyncClient.listResources().resources()); + }); + + verifyCallSucceedsWithImplicitInitialization(client -> client.readResource(resources.get().get(0)), + "reading resources"); } @Test void testReadResource() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); - - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - ReadResourceResult result = mcpSyncClient.readResource(firstResource); + withClient(createMcpTransport(), mcpSyncClient -> { + + int readResourceCount = 0; + + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); + + assertThat(resources).isNotNull(); + assertThat(resources.resources()).isNotNull(); + + assertThat(resources.resources()).isNotNull().isNotEmpty(); + + // Test reading each resource individually for better error isolation + for (Resource resource : resources.resources()) { + ReadResourceResult result = mcpSyncClient.readResource(resource); + + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull().isNotEmpty(); + + readResourceCount++; + + // Validate each content item + for (ResourceContents content : result.contents()) { + assertThat(content).isNotNull(); + assertThat(content.uri()).isNotNull().isNotEmpty(); + assertThat(content.mimeType()).isNotNull().isNotEmpty(); + + // Validate content based on its type with more comprehensive + // checks + switch (content.mimeType()) { + case "text/plain" -> { + TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, content); + assertThat(textContent.text()).isNotNull().isNotEmpty(); + // Verify URI consistency + assertThat(textContent.uri()).isEqualTo(resource.uri()); + } + case "application/octet-stream" -> { + BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, content); + assertThat(blobContent.blob()).isNotNull().isNotEmpty(); + // Verify URI consistency + assertThat(blobContent.uri()).isEqualTo(resource.uri()); + // Validate base64 encoding format + assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$"); + } + default -> { + // More flexible handling of additional MIME types + // Log the unexpected type for debugging but don't fail + // the test + logger.warn("Warning: Encountered unexpected MIME type: {} for resource: {}", + content.mimeType(), resource.uri()); + + // Still validate basic properties + if (content instanceof TextResourceContents textContent) { + assertThat(textContent.text()).isNotNull(); + } + else if (content instanceof BlobResourceContents blobContent) { + assertThat(blobContent.blob()).isNotNull(); + } + } + } + } + } - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - } + // Assert that we read exactly 10 resources + assertThat(readResourceCount).isEqualTo(10); + }); } @Test void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResourceTemplates(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(McpSchema.FIRST_PAGE), + "listing resource templates"); } @Test void testListResourceTemplates() { - mcpSyncClient.initialize(); - ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(McpSchema.FIRST_PAGE); - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }); + } + + @Test + void testListAllResourceTemplates() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(); + + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }); } // @Test void testResourceSubscription() { - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); + // Test subscribe + assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); - // Test unsubscribe - assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - } + // Test unsubscribe + assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + } + }); } @Test @@ -310,19 +535,20 @@ void testNotificationHandlers() { AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesUpdatedNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) - .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) - .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) - .build(); + withClient(createMcpTransport(), + builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) + .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) + .resourcesUpdateConsumer(resources -> resourcesUpdatedNotificationReceived.set(true)), + client -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } // --------------------------------------- @@ -331,40 +557,125 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test void testLoggingLevels() { - mcpSyncClient.initialize(); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); + } + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); + withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) + .loggingConsumer(notification -> logReceived.set(true)), client -> { + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); + } + + @Test + void testLoggingWithNullNotification() { + withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) + .hasMessageContaining("Logging level must not be null")); + } - var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) - .loggingConsumer(notification -> logReceived.set(true)) - .build(); + @Test + void testSampling() { + McpClientTransport transport = createMcpTransport(); - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + final String message = "Hello, world!"; + final String response = "Goodbye, world!"; + final int maxTokens = 100; + + AtomicReference receivedPrompt = new AtomicReference<>(); + AtomicReference receivedMessage = new AtomicReference<>(); + AtomicInteger receivedMaxTokens = new AtomicInteger(); + + withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build()) + .sampling(request -> { + McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class, + request.messages().get(0).content()); + receivedPrompt.set(request.systemPrompt()); + receivedMessage.set(messageText.text()); + receivedMaxTokens.set(request.maxTokens()); + + return new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response), + "modelId", McpSchema.CreateMessageResult.StopReason.END_TURN); + }), client -> { + client.initialize(); + + McpSchema.CallToolResult result = client.callTool( + new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens))); + + // Verify tool response to ensure our sampling response was passed through + assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class); + assertThat(result.content()).allSatisfy(content -> { + if (!(content instanceof McpSchema.TextContent text)) + return; + + assertThat(text.text()).contains(response); + }); + + // Verify sampling request parameters received in our callback + assertThat(receivedPrompt.get()).isNotEmpty(); + assertThat(receivedMessage.get()).endsWith(message); // Prefixed + assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens); + }); } + // --------------------------------------- + // Progress Notification Tests + // --------------------------------------- + @Test - void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null"); + void testProgressConsumer() { + AtomicInteger progressNotificationCount = new AtomicInteger(0); + List receivedNotifications = new CopyOnWriteArrayList<>(); + CountDownLatch latch = new CountDownLatch(2); + + withClient(createMcpTransport(), builder -> builder.progressConsumer(notification -> { + System.out.println("Received progress notification: " + notification); + receivedNotifications.add(notification); + progressNotificationCount.incrementAndGet(); + latch.countDown(); + }), client -> { + client.initialize(); + + // Call a tool that sends progress notifications + CallToolRequest request = CallToolRequest.builder() + .name("longRunningOperation") + .arguments(Map.of("duration", 1, "steps", 2)) + .progressToken("test-token") + .build(); + + CallToolResult result = client.callTool(request); + + assertThat(result).isNotNull(); + + try { + // Wait for progress notifications to be processed + latch.await(3, TimeUnit.SECONDS); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + + assertThat(progressNotificationCount.get()).isEqualTo(2); + + assertThat(receivedNotifications).isNotEmpty(); + assertThat(receivedNotifications.get(0).progressToken()).isEqualTo("test-token"); + }); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index ca5783d0d..d6677ec9a 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -7,7 +7,6 @@ import java.time.Duration; import java.util.List; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -17,21 +16,23 @@ import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Test suite for the {@link McpAsyncServer} that can be used with different - * {@link McpTransport} implementations. + * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations. * * @author Christian Tzolov */ @@ -43,7 +44,7 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected ServerMcpTransport createMcpTransport(); + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); protected void onStart() { } @@ -64,84 +65,210 @@ void tearDown() { // Server Lifecycle Tests // --------------------------------------- - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "sse", "streamable" }) + void testConstructorWithInvalidArguments(String serverType) { + assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo((McpSchema.Implementation) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + McpServer.AsyncSpecification builder = prepareAsyncServerBuilder(); + var mcpAsyncServer = builder.serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); } @Test void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); + assertThatCode(mcpAsyncServer::close).doesNotThrowAnyException(); } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @Test + @Deprecated void testAddTool() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test + void testAddToolCall() { + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() + .tool(newTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build())).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + @Deprecated void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateToolCall() { + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build())).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } + @Test + void testDuplicateToolCallDuringBuilding() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) // Duplicate! + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); + } + + @Test + void testDuplicateToolsInBatchListRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + List specs = List.of( + McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(), + McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build() // Duplicate! + ); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(specs) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-list-tool' is already registered."); + } + + @Test + void testDuplicateToolsInBatchVarargsRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(), + McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build() // Duplicate! + ) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-varargs-tool' is already registered."); + } + @Test void testRemoveTool() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(too, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -151,26 +278,27 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); - }); + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test void testNotifyToolsListChanged() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(too, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); @@ -184,40 +312,55 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } + @Test + void testNotifyResourcesUpdated() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier + .create(mcpAsyncServer + .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + @Test void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete(); + StepVerifier.create(mcpAsyncServer.addResource(specification)).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + void testAddResourceWithNullSpecification() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null)) + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); + assertThat(error).isInstanceOf(IllegalArgumentException.class).hasMessage("Resource must not be null"); }); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -226,41 +369,222 @@ void testAddResourceWithNullRegistration() { @Test void testAddResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); - - StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); }); } @Test void testRemoveResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testListResources() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.listResources().collectList())) + .expectNextMatches(resources -> resources.size() == 1 && resources.get(0).uri().equals(TEST_RESOURCE_URI)) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.removeResource(TEST_RESOURCE_URI))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + StepVerifier.create(mcpAsyncServer.removeResource("nonexistent://resource")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResourceTemplate(specification)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResourceTemplate(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); }); } + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("test://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("nonexistent://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + // Note: Based on the current implementation, listResourceTemplates() returns + // Flux + // This appears to be a bug in the implementation that should return + // Flux + StepVerifier.create(mcpAsyncServer.listResourceTemplates().collectList()) + .expectNextMatches(resources -> resources.size() >= 0) // Just verify it + // doesn't error + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + // --------------------------------------- // Prompts Tests // --------------------------------------- @Test void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); @@ -268,32 +592,30 @@ void testNotifyPromptsListChanged() { } @Test - void testAddPromptWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + void testAddPromptWithNullSpecification() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); - StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null)) + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null"); + assertThat(error).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Prompt specification must not be null"); }); } @Test void testAddPromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); }); } @@ -301,12 +623,10 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); }); } @@ -315,15 +635,14 @@ void testRemovePromptWithoutCapability() { void testRemovePrompt() { String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; - Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List + Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) + .prompts(specification) .build(); StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); @@ -333,15 +652,11 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer2 = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); - StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - }); + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyComplete(); assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) .doesNotThrowAnyException(); @@ -352,14 +667,13 @@ void testRemoveNonexistentPrompt() { // --------------------------------------- @Test - void testRootsChangeConsumers() { + void testRootsChangeHandlers() { // Test with single consumer var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + var singleConsumerServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); @@ -377,12 +691,11 @@ void testRootsChangeConsumers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + var multipleConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumer1Called[0] = true; rootsContent[0] = roots; - }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + }), (exchange, roots) -> Mono.fromRunnable(() -> consumer2Called[0] = true))) .build(); assertThat(multipleConsumersServer).isNotNull(); @@ -391,9 +704,8 @@ void testRootsChangeConsumers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + var errorHandlingServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) .build(); @@ -404,60 +716,11 @@ void testRootsChangeConsumers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) .doesNotThrowAnyException(); } - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); - } - } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index f8b957506..0a59d0aae 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -6,7 +6,6 @@ import java.util.List; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -16,19 +15,19 @@ import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransport} implementations. + * {@link McpServerTransportProvider} implementations. * * @author Christian Tzolov */ @@ -40,7 +39,7 @@ public abstract class AbstractMcpSyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected ServerMcpTransport createMcpTransport(); + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); protected void onStart() { } @@ -64,117 +63,229 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); + assertThatThrownBy(() -> McpServer.sync((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::close).doesNotThrowAnyException(); } @Test void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @Test + @Deprecated void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - assertThatCode(() -> mcpSyncServer - .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false)))) + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, + (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()))) .doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddToolCall() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() + .tool(newTool) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build())).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test + @Deprecated void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool, - args -> new CallToolResult(List.of(), false)))) - .isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, + (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()))) + .doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateToolCall() { + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build())).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testDuplicateToolCallDuringBuilding() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) // Duplicate! + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); + } + + @Test + void testDuplicateToolsInBatchListRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + List specs = List.of( + McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(), + McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build() // Duplicate! + ); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(specs) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-list-tool' is already registered."); + } + + @Test + void testDuplicateToolsInBatchVarargsRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(), + McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, + request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build() // Duplicate! + ) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-varargs-tool' is already registered."); } @Test void testRemoveTool() { - Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); + Tool tool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, args -> new CallToolResult(List.of(), false)) + .toolCall(tool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) - .hasMessage("Tool with name 'nonexistent-tool' not found"); + assertThatCode(() -> mcpSyncServer.removeTool("nonexistent-tool")).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyToolsListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -183,63 +294,257 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpSyncServer::notifyResourcesListChanged).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testNotifyResourcesUpdated() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); + assertThatCode(() -> mcpSyncServer + .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) + .doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException(); + assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + void testAddResourceWithNullSpecification() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null)) - .isInstanceOf(McpError.class) + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Resource must not be null"); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); } @Test void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testListResources() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + List resources = mcpSyncServer.listResources(); + + assertThat(resources).hasSize(1); + assertThat(resources.get(0).uri()).isEqualTo(TEST_RESOURCE_URI); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + assertThatCode(() -> mcpSyncServer.removeResource(TEST_RESOURCE_URI)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + assertThatCode(() -> mcpSyncServer.removeResource("nonexistent://resource")).doesNotThrowAnyException(); - assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResourceTemplate(specification)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResourceTemplate(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("test://template/{id}")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("nonexistent://template/{id}")) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + List templates = mcpSyncServer.listResourceTemplates(); + + assertThat(templates).isNotNull(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -248,75 +553,73 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyPromptsListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test - void testAddPromptWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + void testAddPromptWithNullSpecification() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null)) - .isInstanceOf(McpError.class) - .hasMessage("Prompt registration must not be null"); + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Prompt specification must not be null"); } @Test void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)) + .isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); } @Test void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)) + .isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); } @Test void testRemovePrompt() { - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) + .prompts(specification) .build(); assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + assertThatCode(() -> mcpSyncServer.removePrompt("nonexistent://template/{id}")).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -324,23 +627,21 @@ void testRemoveNonexistentPrompt() { // --------------------------------------- @Test - void testRootsChangeConsumers() { + void testRootsChangeHandlers() { // Test with single consumer var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + var singleConsumerServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); } })) .build(); - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(singleConsumerServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test with multiple consumers @@ -348,84 +649,33 @@ void testRootsChangeConsumers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + var multipleConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { consumer1Called[0] = true; rootsContent[0] = roots; - }, roots -> consumer2Called[0] = true)) + }, (exchange, roots) -> consumer2Called[0] = true)) .build(); assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(multipleConsumersServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + var errorHandlingServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) .build(); assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(errorHandlingServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); + assertThatCode(noConsumersServer::closeGracefully).doesNotThrowAnyException(); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java new file mode 100644 index 000000000..dbbf1a537 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java @@ -0,0 +1,32 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ + +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; + +public class TestUtil { + + TestUtil() { + // Prevent instantiation + } + + /** + * Finds an available port on the local machine. + * @return an available port number + * @throws IllegalStateException if no available port can be found + */ + public static int findAvailablePort() { + try (final ServerSocket socket = new ServerSocket()) { + socket.bind(new InetSocketAddress(0)); + return socket.getLocalPort(); + } + catch (final IOException e) { + throw new IllegalStateException("Cannot bind to an available port!", e); + } + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java b/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java new file mode 100644 index 000000000..723965519 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.json.McpJsonMapper; + +public final class McpJsonMapperUtils { + + private McpJsonMapperUtils() { + } + + public static final McpJsonMapper JSON_MAPPER = McpJsonMapper.getDefault(); + +} \ No newline at end of file diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java b/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java new file mode 100644 index 000000000..ce8755223 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.spec.McpSchema; + +import java.util.Collections; + +public final class ToolsUtils { + + private ToolsUtils() { + } + + public static final McpSchema.JsonSchema EMPTY_JSON_SCHEMA = new McpSchema.JsonSchema("object", + Collections.emptyMap(), null, null, null, null); + +} diff --git a/mcp/pom.xml b/mcp/pom.xml index 2170ffefe..0e0ed1288 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.8.0-SNAPSHOT + 0.18.0-SNAPSHOT mcp jar @@ -20,175 +20,20 @@ git@github.com/modelcontextprotocol/java-sdk.git - - - - biz.aQute.bnd - bnd-maven-plugin - ${bnd-maven-plugin.version} - - - bnd-process - - bnd-process - - - - - - - - - - - org.apache.maven.plugins - maven-jar-plugin - - - ${project.build.outputDirectory}/META-INF/MANIFEST.MF - - - - - - - org.slf4j - slf4j-api - ${slf4j-api.version} - - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - - - io.projectreactor - reactor-core - - - - org.springframework - spring-webmvc - ${springframework.version} - test - - - - - io.projectreactor.netty - reactor-netty-http - test - - - - - org.springframework - spring-context - ${springframework.version} - test - - - - org.springframework - spring-test - ${springframework.version} - test - - - - org.assertj - assertj-core - ${assert4j.version} - test + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 0.18.0-SNAPSHOT - - org.junit.jupiter - junit-jupiter-api - ${junit.version} - test - - - org.mockito - mockito-core - ${mockito.version} - test - - - io.projectreactor - reactor-test - test - - - org.testcontainers - junit-jupiter - ${testcontainers.version} - test - - - - org.awaitility - awaitility - ${awaitility.version} - test - - - - ch.qos.logback - logback-classic - ${logback.version} - test - - - - net.javacrumbs.json-unit - json-unit-assertj - ${json-unit-assertj.version} - test - - - - jakarta.servlet - jakarta.servlet-api - ${jakarta.servlet.version} - provided + io.modelcontextprotocol.sdk + mcp-core + 0.18.0-SNAPSHOT - - - - org.apache.tomcat.embed - tomcat-embed-core - ${tomcat.version} - test - - - org.apache.tomcat.embed - tomcat-embed-websocket - ${tomcat.version} - test - - - \ No newline at end of file + diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java deleted file mode 100644 index 7fc679937..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java +++ /dev/null @@ -1,199 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.client.transport; - -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Flow; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; -import java.util.regex.Pattern; - -/** - * A Server-Sent Events (SSE) client implementation using Java's Flow API for reactive - * stream processing. This client establishes a connection to an SSE endpoint and - * processes the incoming event stream, parsing SSE-formatted messages into structured - * events. - * - *

    - * The client supports standard SSE event fields including: - *

      - *
    • event - The event type (defaults to "message" if not specified)
    • - *
    • id - The event ID
    • - *
    • data - The event payload data
    • - *
    - * - *

    - * Events are delivered to a provided {@link SseEventHandler} which can process events and - * handle any errors that occur during the connection. - * - * @author Christian Tzolov - * @see SseEventHandler - * @see SseEvent - */ -public class FlowSseClient { - - private final HttpClient httpClient; - - /** - * Pattern to extract the data content from SSE data field lines. Matches lines - * starting with "data:" and captures the remaining content. - */ - private static final Pattern EVENT_DATA_PATTERN = Pattern.compile("^data:(.+)$", Pattern.MULTILINE); - - /** - * Pattern to extract the event ID from SSE id field lines. Matches lines starting - * with "id:" and captures the ID value. - */ - private static final Pattern EVENT_ID_PATTERN = Pattern.compile("^id:(.+)$", Pattern.MULTILINE); - - /** - * Pattern to extract the event type from SSE event field lines. Matches lines - * starting with "event:" and captures the event type. - */ - private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE); - - /** - * Record class representing a Server-Sent Event with its standard fields. - * - * @param id the event ID (may be null) - * @param type the event type (defaults to "message" if not specified in the stream) - * @param data the event payload data - */ - public static record SseEvent(String id, String type, String data) { - } - - /** - * Interface for handling SSE events and errors. Implementations can process received - * events and handle any errors that occur during the SSE connection. - */ - public interface SseEventHandler { - - /** - * Called when an SSE event is received. - * @param event the received SSE event containing id, type, and data - */ - void onEvent(SseEvent event); - - /** - * Called when an error occurs during the SSE connection. - * @param error the error that occurred - */ - void onError(Throwable error); - - } - - /** - * Creates a new FlowSseClient with the specified HTTP client. - * @param httpClient the {@link HttpClient} instance to use for SSE connections - */ - public FlowSseClient(HttpClient httpClient) { - this.httpClient = httpClient; - } - - /** - * Subscribes to an SSE endpoint and processes the event stream. - * - *

    - * This method establishes a connection to the specified URL and begins processing the - * SSE stream. Events are parsed and delivered to the provided event handler. The - * connection remains active until either an error occurs or the server closes the - * connection. - * @param url the SSE endpoint URL to connect to - * @param eventHandler the handler that will receive SSE events and error - * notifications - * @throws RuntimeException if the connection fails with a non-200 status code - */ - public void subscribe(String url, SseEventHandler eventHandler) { - HttpRequest request = HttpRequest.newBuilder() - .uri(URI.create(url)) - .header("Accept", "text/event-stream") - .header("Cache-Control", "no-cache") - .GET() - .build(); - - StringBuilder eventBuilder = new StringBuilder(); - AtomicReference currentEventId = new AtomicReference<>(); - AtomicReference currentEventType = new AtomicReference<>("message"); - - Flow.Subscriber lineSubscriber = new Flow.Subscriber<>() { - private Flow.Subscription subscription; - - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.subscription = subscription; - subscription.request(Long.MAX_VALUE); - } - - @Override - public void onNext(String line) { - if (line.isEmpty()) { - // Empty line means end of event - if (eventBuilder.length() > 0) { - String eventData = eventBuilder.toString(); - SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - eventHandler.onEvent(event); - eventBuilder.setLength(0); - } - } - else { - if (line.startsWith("data:")) { - var matcher = EVENT_DATA_PATTERN.matcher(line); - if (matcher.find()) { - eventBuilder.append(matcher.group(1).trim()).append("\n"); - } - } - else if (line.startsWith("id:")) { - var matcher = EVENT_ID_PATTERN.matcher(line); - if (matcher.find()) { - currentEventId.set(matcher.group(1).trim()); - } - } - else if (line.startsWith("event:")) { - var matcher = EVENT_TYPE_PATTERN.matcher(line); - if (matcher.find()) { - currentEventType.set(matcher.group(1).trim()); - } - } - } - subscription.request(1); - } - - @Override - public void onError(Throwable throwable) { - eventHandler.onError(throwable); - } - - @Override - public void onComplete() { - // Handle any remaining event data - if (eventBuilder.length() > 0) { - String eventData = eventBuilder.toString(); - SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - eventHandler.onEvent(event); - } - } - }; - - Function, HttpResponse.BodySubscriber> subscriberFactory = subscriber -> HttpResponse.BodySubscribers - .fromLineSubscriber(subscriber); - - CompletableFuture> future = this.httpClient.sendAsync(request, - info -> subscriberFactory.apply(lineSubscriber)); - - future.thenAccept(response -> { - int status = response.statusCode(); - if (status != 200 && status != 201 && status != 202 && status != 206) { - throw new RuntimeException("Failed to connect to SSE stream. Unexpected status code: " + status); - } - }).exceptionally(throwable -> { - eventHandler.onError(throwable); - return null; - }); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java deleted file mode 100644 index 35da51970..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ -package io.modelcontextprotocol.client.transport; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; -import io.modelcontextprotocol.spec.ClientMcpTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; - -import java.io.IOException; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; - -/** - * Server-Sent Events (SSE) implementation of the - * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE - * transport specification, using Java's HttpClient. - * - *

    - * This transport implementation establishes a bidirectional communication channel between - * client and server using SSE for server-to-client messages and HTTP POST requests for - * client-to-server messages. The transport: - *

      - *
    • Establishes an SSE connection to receive server messages
    • - *
    • Handles endpoint discovery through SSE events
    • - *
    • Manages message serialization/deserialization using Jackson
    • - *
    • Provides graceful connection termination
    • - *
    - * - *

    - * The transport supports two types of SSE events: - *

      - *
    • 'endpoint' - Contains the URL for sending client messages
    • - *
    • 'message' - Contains JSON-RPC message payload
    • - *
    - * - * @author Christian Tzolov - * @see io.modelcontextprotocol.spec.McpTransport - * @see io.modelcontextprotocol.spec.ClientMcpTransport - */ -public class HttpClientSseClientTransport implements ClientMcpTransport { - - private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class); - - /** SSE event type for JSON-RPC messages */ - private static final String MESSAGE_EVENT_TYPE = "message"; - - /** SSE event type for endpoint discovery */ - private static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** Default SSE endpoint path */ - private static final String SSE_ENDPOINT = "/sse"; - - /** Base URI for the MCP server */ - private final String baseUri; - - /** SSE client for handling server-sent events. Uses the /sse endpoint */ - private final FlowSseClient sseClient; - - /** - * HTTP client for sending messages to the server. Uses HTTP POST over the message - * endpoint - */ - private final HttpClient httpClient; - - /** JSON object mapper for message serialization/deserialization */ - protected ObjectMapper objectMapper; - - /** Flag indicating if the transport is in closing state */ - private volatile boolean isClosing = false; - - /** Latch for coordinating endpoint discovery */ - private final CountDownLatch closeLatch = new CountDownLatch(1); - - /** Holds the discovered message endpoint URL */ - private final AtomicReference messageEndpoint = new AtomicReference<>(); - - /** Holds the SSE connection future */ - private final AtomicReference> connectionFuture = new AtomicReference<>(); - - /** - * Creates a new transport instance with default HTTP client and object mapper. - * @param baseUri the base URI of the MCP server - */ - public HttpClientSseClientTransport(String baseUri) { - this(HttpClient.newBuilder(), baseUri, new ObjectMapper()); - } - - /** - * Creates a new transport instance with custom HTTP client builder and object mapper. - * @param clientBuilder the HTTP client builder to use - * @param baseUri the base URI of the MCP server - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper or clientBuilder is null - */ - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.hasText(baseUri, "baseUri must not be empty"); - Assert.notNull(clientBuilder, "clientBuilder must not be null"); - this.baseUri = baseUri; - this.objectMapper = objectMapper; - this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(); - this.sseClient = new FlowSseClient(this.httpClient); - } - - /** - * Establishes the SSE connection with the server and sets up message handling. - * - *

    - * This method: - *

      - *
    • Initiates the SSE connection
    • - *
    • Handles endpoint discovery events
    • - *
    • Processes incoming JSON-RPC messages
    • - *
    - * @param handler the function to process received JSON-RPC messages - * @return a Mono that completes when the connection is established - */ - @Override - public Mono connect(Function, Mono> handler) { - CompletableFuture future = new CompletableFuture<>(); - connectionFuture.set(future); - - sseClient.subscribe(this.baseUri + SSE_ENDPOINT, new FlowSseClient.SseEventHandler() { - @Override - public void onEvent(SseEvent event) { - if (isClosing) { - return; - } - - try { - if (ENDPOINT_EVENT_TYPE.equals(event.type())) { - String endpoint = event.data(); - messageEndpoint.set(endpoint); - closeLatch.countDown(); - future.complete(null); - } - else if (MESSAGE_EVENT_TYPE.equals(event.type())) { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); - handler.apply(Mono.just(message)).subscribe(); - } - else { - logger.error("Received unrecognized SSE event type: {}", event.type()); - } - } - catch (IOException e) { - logger.error("Error processing SSE event", e); - future.completeExceptionally(e); - } - } - - @Override - public void onError(Throwable error) { - if (!isClosing) { - logger.error("SSE connection error", error); - future.completeExceptionally(error); - } - } - }); - - return Mono.fromFuture(future); - } - - /** - * Sends a JSON-RPC message to the server. - * - *

    - * This method waits for the message endpoint to be discovered before sending the - * message. The message is serialized to JSON and sent as an HTTP POST request. - * @param message the JSON-RPC message to send - * @return a Mono that completes when the message is sent - * @throws McpError if the message endpoint is not available or the wait times out - */ - @Override - public Mono sendMessage(JSONRPCMessage message) { - if (isClosing) { - return Mono.empty(); - } - - try { - if (!closeLatch.await(10, TimeUnit.SECONDS)) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); - } - } - catch (InterruptedException e) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); - } - - String endpoint = messageEndpoint.get(); - if (endpoint == null) { - return Mono.error(new McpError("No message endpoint available")); - } - - try { - String jsonText = this.objectMapper.writeValueAsString(message); - HttpRequest request = HttpRequest.newBuilder() - .uri(URI.create(this.baseUri + endpoint)) - .header("Content-Type", "application/json") - .POST(HttpRequest.BodyPublishers.ofString(jsonText)) - .build(); - - return Mono.fromFuture( - httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> { - if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 - && response.statusCode() != 206) { - logger.error("Error sending message: {}", response.statusCode()); - } - })); - } - catch (IOException e) { - if (!isClosing) { - return Mono.error(new RuntimeException("Failed to serialize message", e)); - } - return Mono.empty(); - } - } - - /** - * Gracefully closes the transport connection. - * - *

    - * Sets the closing flag and cancels any pending connection future. This prevents new - * messages from being sent and allows ongoing operations to complete. - * @return a Mono that completes when the closing process is initiated - */ - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - isClosing = true; - CompletableFuture future = connectionFuture.get(); - if (future != null && !future.isDone()) { - future.cancel(true); - } - }); - } - - /** - * Unmarshals data to the specified type using the configured object mapper. - * @param data the data to unmarshal - * @param typeRef the type reference for the target type - * @param the target type - * @return the unmarshalled object - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java deleted file mode 100644 index 7b6916785..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ /dev/null @@ -1,678 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.time.Duration; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import io.modelcontextprotocol.spec.DefaultMcpSession; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import io.modelcontextprotocol.spec.DefaultMcpSession.NotificationHandler; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; -import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.util.Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** - * The Model Context Protocol (MCP) server implementation that provides asynchronous - * communication using Project Reactor's Mono and Flux types. - * - *

    - * This server implements the MCP specification, enabling AI models to expose tools, - * resources, and prompts through a standardized interface. Key features include: - *

      - *
    • Asynchronous communication using reactive programming patterns - *
    • Dynamic tool registration and management - *
    • Resource handling with URI-based addressing - *
    • Prompt template management - *
    • Real-time client notifications for state changes - *
    • Structured logging with configurable severity levels - *
    • Support for client-side AI model sampling - *
    - * - *

    - * The server follows a lifecycle: - *

      - *
    1. Initialization - Accepts client connections and negotiates capabilities - *
    2. Normal Operation - Handles client requests and sends notifications - *
    3. Graceful Shutdown - Ensures clean connection termination - *
    - * - *

    - * This implementation uses Project Reactor for non-blocking operations, making it - * suitable for high-throughput scenarios and reactive applications. All operations return - * Mono or Flux types that can be composed into reactive pipelines. - * - *

    - * The server supports runtime modification of its capabilities through methods like - * {@link #addTool}, {@link #addResource}, and {@link #addPrompt}, automatically notifying - * connected clients of changes when configured to do so. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - * @see McpServer - * @see McpSchema - * @see DefaultMcpSession - */ -public class McpAsyncServer { - - private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); - - /** - * The MCP session implementation that manages bidirectional JSON-RPC communication - * between clients and servers. - */ - private final DefaultMcpSession mcpSession; - - private final ServerMcpTransport transport; - - private final McpSchema.ServerCapabilities serverCapabilities; - - private final McpSchema.Implementation serverInfo; - - private McpSchema.ClientCapabilities clientCapabilities; - - private McpSchema.Implementation clientInfo; - - /** - * Thread-safe list of tool handlers that can be modified at runtime. - */ - private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); - - private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); - - private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; - - /** - * Supported protocol versions. - */ - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); - - /** - * Create a new McpAsyncServer with the given transport and capabilities. - * @param mcpTransport The transport layer implementation for MCP communication. - * @param features The MCP server supported features. - */ - McpAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) { - - this.serverInfo = features.serverInfo(); - this.serverCapabilities = features.serverCapabilities(); - this.tools.addAll(features.tools()); - this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); - this.prompts.putAll(features.prompts()); - - Map> requestHandlers = new HashMap<>(); - - // Initialize request handlers for standard MCP methods - requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); - - // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (params) -> Mono.just("")); - - // Add tools API handlers if the tool capability is enabled - if (this.serverCapabilities.tools() != null) { - requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); - } - - // Add resources API handlers if provided - if (this.serverCapabilities.resources() != null) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); - } - - // Add prompts API handlers if provider exists - if (this.serverCapabilities.prompts() != null) { - requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); - } - - // Add logging API handlers if the logging capability is enabled - if (this.serverCapabilities.logging() != null) { - requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); - } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (params) -> Mono.empty()); - - List, Mono>> rootsChangeConsumers = features.rootsChangeConsumers(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger - .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - this.transport = mcpTransport; - this.mcpSession = new DefaultMcpSession(Duration.ofSeconds(10), mcpTransport, requestHandlers, - notificationHandlers); - } - - // --------------------------------------- - // Lifecycle Management - // --------------------------------------- - private DefaultMcpSession.RequestHandler asyncInitializeRequestHandler() { - return params -> { - McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - this.clientCapabilities = initializeRequest.capabilities(); - this.clientInfo = initializeRequest.clientInfo(); - logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", - initializeRequest.protocolVersion(), initializeRequest.capabilities(), - initializeRequest.clientInfo()); - - // The server MUST respond with the highest protocol version it supports if - // it does not support the requested (e.g. Client) version. - String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { - // If the server supports the requested protocol version, it MUST respond - // with the same version. - serverProtocolVersion = initializeRequest.protocolVersion(); - } - else { - logger.warn( - "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", - initializeRequest.protocolVersion(), serverProtocolVersion); - } - - return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, - this.serverInfo, null)); - }; - } - - /** - * Get the server capabilities that define the supported features and functionality. - * @return The server capabilities - */ - public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; - } - - /** - * Get the server implementation information. - * @return The server implementation details - */ - public McpSchema.Implementation getServerInfo() { - return this.serverInfo; - } - - /** - * Get the client capabilities that define the supported features and functionality. - * @return The client capabilities - */ - public ClientCapabilities getClientCapabilities() { - return this.clientCapabilities; - } - - /** - * Get the client implementation information. - * @return The client implementation details - */ - public McpSchema.Implementation getClientInfo() { - return this.clientInfo; - } - - /** - * Gracefully closes the server, allowing any in-progress operations to complete. - * @return A Mono that completes when the server has been closed - */ - public Mono closeGracefully() { - return this.mcpSession.closeGracefully(); - } - - /** - * Close the server immediately. - */ - public void close() { - this.mcpSession.close(); - } - - private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { - }; - - /** - * Retrieves the list of all roots provided by the client. - * @return A Mono that emits the list of roots result. - */ - public Mono listRoots() { - return this.listRoots(null); - } - - /** - * Retrieves a paginated list of roots provided by the server. - * @param cursor Optional pagination cursor from a previous list request - * @return A Mono that emits the list of roots result containing - */ - public Mono listRoots(String cursor) { - return this.mcpSession.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_ROOTS_RESULT_TYPE_REF); - } - - private NotificationHandler asyncRootsListChangedNotificationHandler( - List, Mono>> rootsChangeConsumers) { - return params -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); - } - - // --------------------------------------- - // Tool Management - // --------------------------------------- - - /** - * Add a new tool registration at runtime. - * @param toolRegistration The tool registration to add - * @return Mono that completes when clients have been notified of the change - */ - public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { - if (toolRegistration == null) { - return Mono.error(new McpError("Tool registration must not be null")); - } - if (toolRegistration.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); - } - if (toolRegistration.call() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { - return Mono - .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); - } - - this.tools.add(toolRegistration); - logger.debug("Added tool handler: {}", toolRegistration.tool().name()); - - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - }); - } - - /** - * Remove a tool handler at runtime. - * @param toolName The name of the tool handler to remove - * @return Mono that completes when clients have been notified of the change - */ - public Mono removeTool(String toolName) { - if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - boolean removed = this.tools.removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); - if (removed) { - logger.debug("Removed tool handler: {}", toolName); - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); - }); - } - - /** - * Notifies clients that the list of available tools has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyToolsListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); - } - - private DefaultMcpSession.RequestHandler toolsListRequestHandler() { - return params -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::tool).toList(); - - return Mono.just(new McpSchema.ListToolsResult(tools, null)); - }; - } - - private DefaultMcpSession.RequestHandler toolsCallRequestHandler() { - return params -> { - McpSchema.CallToolRequest callToolRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - - Optional toolRegistration = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); - - if (toolRegistration.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); - } - - return toolRegistration.map(tool -> tool.call().apply(callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); - }; - } - - // --------------------------------------- - // Resource Management - // --------------------------------------- - - /** - * Add a new resource handler at runtime. - * @param resourceHandler The resource handler to add - * @return Mono that completes when clients have been notified of the change - */ - public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { - if (resourceHandler == null || resourceHandler.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); - } - - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } - - return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceHandler.resource().uri(), resourceHandler) != null) { - return Mono - .error(new McpError("Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); - } - logger.debug("Added resource handler: {}", resourceHandler.resource().uri()); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - }); - } - - /** - * Remove a resource handler at runtime. - * @param resourceUri The URI of the resource handler to remove - * @return Mono that completes when clients have been notified of the change - */ - public Mono removeResource(String resourceUri) { - if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); - } - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncResourceRegistration removed = this.resources.remove(resourceUri); - if (removed != null) { - logger.debug("Removed resource handler: {}", resourceUri); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); - }); - } - - /** - * Notifies clients that the list of available resources has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyResourcesListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); - } - - private DefaultMcpSession.RequestHandler resourcesListRequestHandler() { - return params -> { - var resourceList = this.resources.values() - .stream() - .map(McpServerFeatures.AsyncResourceRegistration::resource) - .toList(); - return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); - }; - } - - private DefaultMcpSession.RequestHandler resourceTemplateListRequestHandler() { - return params -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); - - } - - private DefaultMcpSession.RequestHandler resourcesReadRequestHandler() { - return params -> { - McpSchema.ReadResourceRequest resourceRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceRegistration registration = this.resources.get(resourceUri); - if (registration != null) { - return registration.readHandler().apply(resourceRequest); - } - return Mono.error(new McpError("Resource not found: " + resourceUri)); - }; - } - - // --------------------------------------- - // Prompt Management - // --------------------------------------- - - /** - * Add a new prompt handler at runtime. - * @param promptRegistration The prompt handler to add - * @return Mono that completes when clients have been notified of the change - */ - public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { - if (promptRegistration == null) { - return Mono.error(new McpError("Prompt registration must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptRegistration registration = this.prompts - .putIfAbsent(promptRegistration.prompt().name(), promptRegistration); - if (registration != null) { - return Mono.error( - new McpError("Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); - } - - logger.debug("Added prompt handler: {}", promptRegistration.prompt().name()); - - // Servers that declared the listChanged capability SHOULD send a - // notification, - // when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); - } - return Mono.empty(); - }); - } - - /** - * Remove a prompt handler at runtime. - * @param promptName The name of the prompt handler to remove - * @return Mono that completes when clients have been notified of the change - */ - public Mono removePrompt(String promptName) { - if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptRegistration removed = this.prompts.remove(promptName); - - if (removed != null) { - logger.debug("Removed prompt handler: {}", promptName); - // Servers that declared the listChanged capability SHOULD send a - // notification, when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return this.notifyPromptsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); - }); - } - - /** - * Notifies clients that the list of available prompts has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyPromptsListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); - } - - private DefaultMcpSession.RequestHandler promptsListRequestHandler() { - return params -> { - // TODO: Implement pagination - // McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, - // new TypeReference() { - // }); - - var promptList = this.prompts.values() - .stream() - .map(McpServerFeatures.AsyncPromptRegistration::prompt) - .toList(); - - return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); - }; - } - - private DefaultMcpSession.RequestHandler promptsGetRequestHandler() { - return params -> { - McpSchema.GetPromptRequest promptRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - - // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptRegistration registration = this.prompts.get(promptRequest.name()); - if (registration == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); - } - - return registration.promptHandler().apply(promptRequest); - }; - } - - // --------------------------------------- - // Logging Management - // --------------------------------------- - - /** - * Send a logging message notification to all connected clients. Messages below the - * current minimum logging level will be filtered out. - * @param loggingMessageNotification The logging message to send - * @return A Mono that completes when the notification has been sent - */ - public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - - if (loggingMessageNotification == null) { - return Mono.error(new McpError("Logging message must not be null")); - } - - Map params = this.transport.unmarshalFrom(loggingMessageNotification, - new TypeReference>() { - }); - - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { - return Mono.empty(); - } - - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); - } - - /** - * Handles requests to set the minimum logging level. Messages below this level will - * not be sent. - * @return A handler that processes logging level change requests - */ - private DefaultMcpSession.RequestHandler setLoggerRequestHandler() { - return params -> { - this.minLoggingLevel = transport.unmarshalFrom(params, new TypeReference() { - }); - - return Mono.empty(); - }; - } - - // --------------------------------------- - // Sampling - // --------------------------------------- - private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { - }; - - /** - * Create a new message using the sampling capabilities of the client. The Model - * Context Protocol (MCP) provides a standardized way for servers to request LLM - * sampling (“completions” or “generations”) from language models via clients. This - * flow allows clients to maintain control over model access, selection, and - * permissions while enabling servers to leverage AI capabilities—with no server API - * keys necessary. Servers can request text or image-based interactions and optionally - * include context from MCP servers in their prompts. - * @param createMessageRequest The request to create a new message - * @return A Mono that completes when the message has been created - * @throws McpError if the client has not been initialized or does not support - * sampling capabilities - * @throws McpError if the client does not support the createMessage method - * @see McpSchema.CreateMessageRequest - * @see McpSchema.CreateMessageResult - * @see Sampling - * Specification - */ - public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { - - if (this.clientCapabilities == null) { - return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); - } - if (this.clientCapabilities.sampling() == null) { - return Mono.error(new McpError("Client must be configured with sampling capabilities")); - } - return this.mcpSession.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, - CREATE_MESSAGE_RESULT_TYPE_REF); - } - - /** - * This method is package-private and used for test only. Should not be called by user - * code. - * @param protocolVersions the Client supported protocol versions. - */ - void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java deleted file mode 100644 index 54c7a28fd..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ /dev/null @@ -1,897 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; - -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; -import io.modelcontextprotocol.util.Assert; -import reactor.core.publisher.Mono; - -/** - * Factory class for creating Model Context Protocol (MCP) servers. MCP servers expose - * tools, resources, and prompts to AI models through a standardized interface. - * - *

    - * This class serves as the main entry point for implementing the server-side of the MCP - * specification. The server's responsibilities include: - *

      - *
    • Exposing tools that models can invoke to perform actions - *
    • Providing access to resources that give models context - *
    • Managing prompt templates for structured model interactions - *
    • Handling client connections and requests - *
    • Implementing capability negotiation - *
    - * - *

    - * Thread Safety: Both synchronous and asynchronous server implementations are - * thread-safe. The synchronous server processes requests sequentially, while the - * asynchronous server can handle concurrent requests safely through its reactive - * programming model. - * - *

    - * Error Handling: The server implementations provide robust error handling through the - * McpError class. Errors are properly propagated to clients while maintaining the - * server's stability. Server implementations should use appropriate error codes and - * provide meaningful error messages to help diagnose issues. - * - *

    - * The class provides factory methods to create either: - *

      - *
    • {@link McpAsyncServer} for non-blocking operations with CompletableFuture responses - *
    • {@link McpSyncServer} for blocking operations with direct responses - *
    - * - *

    - * Example of creating a basic synchronous server:

    {@code
    - * McpServer.sync(transport)
    - *     .serverInfo("my-server", "1.0.0")
    - *     .tool(new Tool("calculator", "Performs calculations", schema),
    - *           args -> new CallToolResult("Result: " + calculate(args)))
    - *     .build();
    - * }
    - * - * Example of creating a basic asynchronous server:
    {@code
    - * McpServer.async(transport)
    - *     .serverInfo("my-server", "1.0.0")
    - *     .tool(new Tool("calculator", "Performs calculations", schema),
    - *           args -> Mono.just(new CallToolResult("Result: " + calculate(args))))
    - *     .build();
    - * }
    - * - *

    - * Example with comprehensive asynchronous configuration:

    {@code
    - * McpServer.async(transport)
    - *     .serverInfo("advanced-server", "2.0.0")
    - *     .capabilities(new ServerCapabilities(...))
    - *     // Register tools
    - *     .tools(
    - *         new McpServerFeatures.AsyncToolRegistration(calculatorTool,
    - *             args -> Mono.just(new CallToolResult("Result: " + calculate(args)))),
    - *         new McpServerFeatures.AsyncToolRegistration(weatherTool,
    - *             args -> Mono.just(new CallToolResult("Weather: " + getWeather(args))))
    - *     )
    - *     // Register resources
    - *     .resources(
    - *         new McpServerFeatures.AsyncResourceRegistration(fileResource,
    - *             req -> Mono.just(new ReadResourceResult(readFile(req)))),
    - *         new McpServerFeatures.AsyncResourceRegistration(dbResource,
    - *             req -> Mono.just(new ReadResourceResult(queryDb(req))))
    - *     )
    - *     // Add resource templates
    - *     .resourceTemplates(
    - *         new ResourceTemplate("file://{path}", "Access files"),
    - *         new ResourceTemplate("db://{table}", "Access database")
    - *     )
    - *     // Register prompts
    - *     .prompts(
    - *         new McpServerFeatures.AsyncPromptRegistration(analysisPrompt,
    - *             req -> Mono.just(new GetPromptResult(generateAnalysisPrompt(req)))),
    - *         new McpServerFeatures.AsyncPromptRegistration(summaryPrompt,
    - *             req -> Mono.just(new GetPromptResult(generateSummaryPrompt(req))))
    - *     )
    - *     .build();
    - * }
    - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - * @see McpAsyncServer - * @see McpSyncServer - * @see McpTransport - */ -public interface McpServer { - - /** - * Starts building a synchronous MCP server that provides blocking operations. - * Synchronous servers process each request to completion before handling the next - * one, making them simpler to implement but potentially less performant for - * concurrent operations. - * @param transport The transport layer implementation for MCP communication - * @return A new instance of {@link SyncSpec} for configuring the server. - */ - static SyncSpec sync(ServerMcpTransport transport) { - return new SyncSpec(transport); - } - - /** - * Starts building an asynchronous MCP server that provides blocking operations. - * Asynchronous servers can handle multiple requests concurrently using a functional - * paradigm with non-blocking server transports, making them more efficient for - * high-concurrency scenarios but more complex to implement. - * @param transport The transport layer implementation for MCP communication - * @return A new instance of {@link SyncSpec} for configuring the server. - */ - static AsyncSpec async(ServerMcpTransport transport) { - return new AsyncSpec(transport); - } - - /** - * Asynchronous server specification. - */ - class AsyncSpec { - - private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", - "1.0.0"); - - private final ServerMcpTransport transport; - - private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - - private McpSchema.ServerCapabilities serverCapabilities; - - /** - * The Model Context Protocol (MCP) allows servers to expose tools that can be - * invoked by language models. Tools enable models to interact with external - * systems, such as querying databases, calling APIs, or performing computations. - * Each tool is uniquely identified by a name and includes metadata describing its - * schema. - */ - private final List tools = new ArrayList<>(); - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to - * expose resources to clients. Resources allow servers to share data that - * provides context to language models, such as files, database schemas, or - * application-specific information. Each resource is uniquely identified by a - * URI. - */ - private final Map resources = new HashMap<>(); - - private final List resourceTemplates = new ArrayList<>(); - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to - * expose prompt templates to clients. Prompts allow servers to provide structured - * messages and instructions for interacting with language models. Clients can - * discover available prompts, retrieve their contents, and provide arguments to - * customize them. - */ - private final Map prompts = new HashMap<>(); - - private final List, Mono>> rootsChangeConsumers = new ArrayList<>(); - - private AsyncSpec(ServerMcpTransport transport) { - Assert.notNull(transport, "Transport must not be null"); - this.transport = transport; - } - - /** - * Sets the server implementation information that will be shared with clients - * during connection initialization. This helps with version compatibility, - * debugging, and server identification. - * @param serverInfo The server implementation details including name and version. - * Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if serverInfo is null - */ - public AsyncSpec serverInfo(McpSchema.Implementation serverInfo) { - Assert.notNull(serverInfo, "Server info must not be null"); - this.serverInfo = serverInfo; - return this; - } - - /** - * Sets the server implementation information using name and version strings. This - * is a convenience method alternative to - * {@link #serverInfo(McpSchema.Implementation)}. - * @param name The server name. Must not be null or empty. - * @param version The server version. Must not be null or empty. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if name or version is null or empty - * @see #serverInfo(McpSchema.Implementation) - */ - public AsyncSpec serverInfo(String name, String version) { - Assert.hasText(name, "Name must not be null or empty"); - Assert.hasText(version, "Version must not be null or empty"); - this.serverInfo = new McpSchema.Implementation(name, version); - return this; - } - - /** - * Sets the server capabilities that will be advertised to clients during - * connection initialization. Capabilities define what features the server - * supports, such as: - *
      - *
    • Tool execution - *
    • Resource access - *
    • Prompt handling - *
    • Streaming responses - *
    • Batch operations - *
    - * @param serverCapabilities The server capabilities configuration. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if serverCapabilities is null - */ - public AsyncSpec capabilities(McpSchema.ServerCapabilities serverCapabilities) { - this.serverCapabilities = serverCapabilities; - return this; - } - - /** - * Adds a single tool with its implementation handler to the server. This is a - * convenience method for registering individual tools without creating a - * {@link McpServerFeatures.AsyncToolRegistration} explicitly. - * - *

    - * Example usage:

    {@code
    -		 * .tool(
    -		 *     new Tool("calculator", "Performs calculations", schema),
    -		 *     args -> Mono.just(new CallToolResult("Result: " + calculate(args)))
    -		 * )
    -		 * }
    - * @param tool The tool definition including name, description, and schema. Must - * not be null. - * @param handler The function that implements the tool's logic. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if tool or handler is null - */ - public AsyncSpec tool(McpSchema.Tool tool, Function, Mono> handler) { - Assert.notNull(tool, "Tool must not be null"); - Assert.notNull(handler, "Handler must not be null"); - - this.tools.add(new McpServerFeatures.AsyncToolRegistration(tool, handler)); - - return this; - } - - /** - * Adds multiple tools with their handlers to the server using a List. This method - * is useful when tools are dynamically generated or loaded from a configuration - * source. - * @param toolRegistrations The list of tool registrations to add. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolRegistrations is null - * @see #tools(McpServerFeatures.AsyncToolRegistration...) - */ - public AsyncSpec tools(List toolRegistrations) { - Assert.notNull(toolRegistrations, "Tool handlers list must not be null"); - this.tools.addAll(toolRegistrations); - return this; - } - - /** - * Adds multiple tools with their handlers to the server using varargs. This - * method provides a convenient way to register multiple tools inline. - * - *

    - * Example usage:

    {@code
    -		 * .tools(
    -		 *     new McpServerFeatures.AsyncToolRegistration(calculatorTool, calculatorHandler),
    -		 *     new McpServerFeatures.AsyncToolRegistration(weatherTool, weatherHandler),
    -		 *     new McpServerFeatures.AsyncToolRegistration(fileManagerTool, fileManagerHandler)
    -		 * )
    -		 * }
    - * @param toolRegistrations The tool registrations to add. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolRegistrations is null - * @see #tools(List) - */ - public AsyncSpec tools(McpServerFeatures.AsyncToolRegistration... toolRegistrations) { - for (McpServerFeatures.AsyncToolRegistration tool : toolRegistrations) { - this.tools.add(tool); - } - return this; - } - - /** - * Registers multiple resources with their handlers using a Map. This method is - * useful when resources are dynamically generated or loaded from a configuration - * source. - * @param resourceRegsitrations Map of resource name to registration. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(McpServerFeatures.AsyncResourceRegistration...) - */ - public AsyncSpec resources(Map resourceRegsitrations) { - Assert.notNull(resourceRegsitrations, "Resource handlers map must not be null"); - this.resources.putAll(resourceRegsitrations); - return this; - } - - /** - * Registers multiple resources with their handlers using a List. This method is - * useful when resources need to be added in bulk from a collection. - * @param resourceRegsitrations List of resource registrations. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(McpServerFeatures.AsyncResourceRegistration...) - */ - public AsyncSpec resources(List resourceRegsitrations) { - Assert.notNull(resourceRegsitrations, "Resource handlers list must not be null"); - for (McpServerFeatures.AsyncResourceRegistration resource : resourceRegsitrations) { - this.resources.put(resource.resource().uri(), resource); - } - return this; - } - - /** - * Registers multiple resources with their handlers using varargs. This method - * provides a convenient way to register multiple resources inline. - * - *

    - * Example usage:

    {@code
    -		 * .resources(
    -		 *     new McpServerFeatures.AsyncResourceRegistration(fileResource, fileHandler),
    -		 *     new McpServerFeatures.AsyncResourceRegistration(dbResource, dbHandler),
    -		 *     new McpServerFeatures.AsyncResourceRegistration(apiResource, apiHandler)
    -		 * )
    -		 * }
    - * @param resourceRegistrations The resource registrations to add. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegistrations is null - */ - public AsyncSpec resources(McpServerFeatures.AsyncResourceRegistration... resourceRegistrations) { - Assert.notNull(resourceRegistrations, "Resource handlers list must not be null"); - for (McpServerFeatures.AsyncResourceRegistration resource : resourceRegistrations) { - this.resources.put(resource.resource().uri(), resource); - } - return this; - } - - /** - * Sets the resource templates that define patterns for dynamic resource access. - * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

    - * Example usage:

    {@code
    -		 * .resourceTemplates(
    -		 *     new ResourceTemplate("file://{path}", "Access files by path"),
    -		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
    -		 * )
    -		 * }
    - * @param resourceTemplates List of resource templates. If null, clears existing - * templates. - * @return This builder instance for method chaining - * @see #resourceTemplates(ResourceTemplate...) - */ - public AsyncSpec resourceTemplates(List resourceTemplates) { - this.resourceTemplates.addAll(resourceTemplates); - return this; - } - - /** - * Sets the resource templates using varargs for convenience. This is an - * alternative to {@link #resourceTemplates(List)}. - * @param resourceTemplates The resource templates to set. - * @return This builder instance for method chaining - * @see #resourceTemplates(List) - */ - public AsyncSpec resourceTemplates(ResourceTemplate... resourceTemplates) { - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); - } - return this; - } - - /** - * Registers multiple prompts with their handlers using a Map. This method is - * useful when prompts are dynamically generated or loaded from a configuration - * source. - * - *

    - * Example usage:

    {@code
    -		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptRegistration(
    -		 *     new Prompt("analysis", "Code analysis template"),
    -		 *     request -> Mono.just(new GetPromptResult(generateAnalysisPrompt(request)))
    -		 * )));
    -		 * }
    - * @param prompts Map of prompt name to registration. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - */ - public AsyncSpec prompts(Map prompts) { - this.prompts.putAll(prompts); - return this; - } - - /** - * Registers multiple prompts with their handlers using a List. This method is - * useful when prompts need to be added in bulk from a collection. - * @param prompts List of prompt registrations. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - * @see #prompts(McpServerFeatures.AsyncPromptRegistration...) - */ - public AsyncSpec prompts(List prompts) { - for (McpServerFeatures.AsyncPromptRegistration prompt : prompts) { - this.prompts.put(prompt.prompt().name(), prompt); - } - return this; - } - - /** - * Registers multiple prompts with their handlers using varargs. This method - * provides a convenient way to register multiple prompts inline. - * - *

    - * Example usage:

    {@code
    -		 * .prompts(
    -		 *     new McpServerFeatures.AsyncPromptRegistration(analysisPrompt, analysisHandler),
    -		 *     new McpServerFeatures.AsyncPromptRegistration(summaryPrompt, summaryHandler),
    -		 *     new McpServerFeatures.AsyncPromptRegistration(reviewPrompt, reviewHandler)
    -		 * )
    -		 * }
    - * @param prompts The prompt registrations to add. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - */ - public AsyncSpec prompts(McpServerFeatures.AsyncPromptRegistration... prompts) { - for (McpServerFeatures.AsyncPromptRegistration prompt : prompts) { - this.prompts.put(prompt.prompt().name(), prompt); - } - return this; - } - - /** - * Registers a consumer that will be notified when the list of roots changes. This - * is useful for updating resource availability dynamically, such as when new - * files are added or removed. - * @param consumer The consumer to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumer is null - */ - public AsyncSpec rootsChangeConsumer(Function, Mono> consumer) { - Assert.notNull(consumer, "Consumer must not be null"); - this.rootsChangeConsumers.add(consumer); - return this; - } - - /** - * Registers multiple consumers that will be notified when the list of roots - * changes. This method is useful when multiple consumers need to be registered at - * once. - * @param consumers The list of consumers to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumers is null - */ - public AsyncSpec rootsChangeConsumers(List, Mono>> consumers) { - Assert.notNull(consumers, "Consumers list must not be null"); - this.rootsChangeConsumers.addAll(consumers); - return this; - } - - /** - * Registers multiple consumers that will be notified when the list of roots - * changes using varargs. This method provides a convenient way to register - * multiple consumers inline. - * @param consumers The consumers to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumers is null - */ - public AsyncSpec rootsChangeConsumers( - @SuppressWarnings("unchecked") Function, Mono>... consumers) { - for (Function, Mono> consumer : consumers) { - this.rootsChangeConsumers.add(consumer); - } - return this; - } - - /** - * Builds an asynchronous MCP server that provides non-blocking operations. - * @return A new instance of {@link McpAsyncServer} configured with this builder's - * settings - */ - public McpAsyncServer build() { - return new McpAsyncServer(this.transport, - new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, - this.resourceTemplates, this.prompts, this.rootsChangeConsumers)); - } - - } - - /** - * Synchronous server specification. - */ - class SyncSpec { - - private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", - "1.0.0"); - - private final ServerMcpTransport transport; - - private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - - private McpSchema.ServerCapabilities serverCapabilities; - - /** - * The Model Context Protocol (MCP) allows servers to expose tools that can be - * invoked by language models. Tools enable models to interact with external - * systems, such as querying databases, calling APIs, or performing computations. - * Each tool is uniquely identified by a name and includes metadata describing its - * schema. - */ - private final List tools = new ArrayList<>(); - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to - * expose resources to clients. Resources allow servers to share data that - * provides context to language models, such as files, database schemas, or - * application-specific information. Each resource is uniquely identified by a - * URI. - */ - private final Map resources = new HashMap<>(); - - private final List resourceTemplates = new ArrayList<>(); - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to - * expose prompt templates to clients. Prompts allow servers to provide structured - * messages and instructions for interacting with language models. Clients can - * discover available prompts, retrieve their contents, and provide arguments to - * customize them. - */ - private final Map prompts = new HashMap<>(); - - private final List>> rootsChangeConsumers = new ArrayList<>(); - - private SyncSpec(ServerMcpTransport transport) { - Assert.notNull(transport, "Transport must not be null"); - this.transport = transport; - } - - /** - * Sets the server implementation information that will be shared with clients - * during connection initialization. This helps with version compatibility, - * debugging, and server identification. - * @param serverInfo The server implementation details including name and version. - * Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if serverInfo is null - */ - public SyncSpec serverInfo(McpSchema.Implementation serverInfo) { - Assert.notNull(serverInfo, "Server info must not be null"); - this.serverInfo = serverInfo; - return this; - } - - /** - * Sets the server implementation information using name and version strings. This - * is a convenience method alternative to - * {@link #serverInfo(McpSchema.Implementation)}. - * @param name The server name. Must not be null or empty. - * @param version The server version. Must not be null or empty. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if name or version is null or empty - * @see #serverInfo(McpSchema.Implementation) - */ - public SyncSpec serverInfo(String name, String version) { - Assert.hasText(name, "Name must not be null or empty"); - Assert.hasText(version, "Version must not be null or empty"); - this.serverInfo = new McpSchema.Implementation(name, version); - return this; - } - - /** - * Sets the server capabilities that will be advertised to clients during - * connection initialization. Capabilities define what features the server - * supports, such as: - *
      - *
    • Tool execution - *
    • Resource access - *
    • Prompt handling - *
    • Streaming responses - *
    • Batch operations - *
    - * @param serverCapabilities The server capabilities configuration. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if serverCapabilities is null - */ - public SyncSpec capabilities(McpSchema.ServerCapabilities serverCapabilities) { - this.serverCapabilities = serverCapabilities; - return this; - } - - /** - * Adds a single tool with its implementation handler to the server. This is a - * convenience method for registering individual tools without creating a - * {@link ToolRegistration} explicitly. - * - *

    - * Example usage:

    {@code
    -		 * .tool(
    -		 *     new Tool("calculator", "Performs calculations", schema),
    -		 *     args -> new CallToolResult("Result: " + calculate(args))
    -		 * )
    -		 * }
    - * @param tool The tool definition including name, description, and schema. Must - * not be null. - * @param handler The function that implements the tool's logic. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if tool or handler is null - */ - public SyncSpec tool(McpSchema.Tool tool, Function, McpSchema.CallToolResult> handler) { - Assert.notNull(tool, "Tool must not be null"); - Assert.notNull(handler, "Handler must not be null"); - - this.tools.add(new McpServerFeatures.SyncToolRegistration(tool, handler)); - - return this; - } - - /** - * Adds multiple tools with their handlers to the server using a List. This method - * is useful when tools are dynamically generated or loaded from a configuration - * source. - * @param toolRegistrations The list of tool registrations to add. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolRegistrations is null - * @see #tools(McpServerFeatures.SyncToolRegistration...) - */ - public SyncSpec tools(List toolRegistrations) { - Assert.notNull(toolRegistrations, "Tool handlers list must not be null"); - this.tools.addAll(toolRegistrations); - return this; - } - - /** - * Adds multiple tools with their handlers to the server using varargs. This - * method provides a convenient way to register multiple tools inline. - * - *

    - * Example usage:

    {@code
    -		 * .tools(
    -		 *     new ToolRegistration(calculatorTool, calculatorHandler),
    -		 *     new ToolRegistration(weatherTool, weatherHandler),
    -		 *     new ToolRegistration(fileManagerTool, fileManagerHandler)
    -		 * )
    -		 * }
    - * @param toolRegistrations The tool registrations to add. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolRegistrations is null - * @see #tools(List) - */ - public SyncSpec tools(McpServerFeatures.SyncToolRegistration... toolRegistrations) { - for (McpServerFeatures.SyncToolRegistration tool : toolRegistrations) { - this.tools.add(tool); - } - return this; - } - - /** - * Registers multiple resources with their handlers using a Map. This method is - * useful when resources are dynamically generated or loaded from a configuration - * source. - * @param resourceRegsitrations Map of resource name to registration. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(McpServerFeatures.SyncResourceRegistration...) - */ - public SyncSpec resources(Map resourceRegsitrations) { - Assert.notNull(resourceRegsitrations, "Resource handlers map must not be null"); - this.resources.putAll(resourceRegsitrations); - return this; - } - - /** - * Registers multiple resources with their handlers using a List. This method is - * useful when resources need to be added in bulk from a collection. - * @param resourceRegsitrations List of resource registrations. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(McpServerFeatures.SyncResourceRegistration...) - */ - public SyncSpec resources(List resourceRegsitrations) { - Assert.notNull(resourceRegsitrations, "Resource handlers list must not be null"); - for (McpServerFeatures.SyncResourceRegistration resource : resourceRegsitrations) { - this.resources.put(resource.resource().uri(), resource); - } - return this; - } - - /** - * Registers multiple resources with their handlers using varargs. This method - * provides a convenient way to register multiple resources inline. - * - *

    - * Example usage:

    {@code
    -		 * .resources(
    -		 *     new ResourceRegistration(fileResource, fileHandler),
    -		 *     new ResourceRegistration(dbResource, dbHandler),
    -		 *     new ResourceRegistration(apiResource, apiHandler)
    -		 * )
    -		 * }
    - * @param resourceRegistrations The resource registrations to add. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegistrations is null - */ - public SyncSpec resources(McpServerFeatures.SyncResourceRegistration... resourceRegistrations) { - Assert.notNull(resourceRegistrations, "Resource handlers list must not be null"); - for (McpServerFeatures.SyncResourceRegistration resource : resourceRegistrations) { - this.resources.put(resource.resource().uri(), resource); - } - return this; - } - - /** - * Sets the resource templates that define patterns for dynamic resource access. - * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

    - * Example usage:

    {@code
    -		 * .resourceTemplates(
    -		 *     new ResourceTemplate("file://{path}", "Access files by path"),
    -		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
    -		 * )
    -		 * }
    - * @param resourceTemplates List of resource templates. If null, clears existing - * templates. - * @return This builder instance for method chaining - * @see #resourceTemplates(ResourceTemplate...) - */ - public SyncSpec resourceTemplates(List resourceTemplates) { - this.resourceTemplates.addAll(resourceTemplates); - return this; - } - - /** - * Sets the resource templates using varargs for convenience. This is an - * alternative to {@link #resourceTemplates(List)}. - * @param resourceTemplates The resource templates to set. - * @return This builder instance for method chaining - * @see #resourceTemplates(List) - */ - public SyncSpec resourceTemplates(ResourceTemplate... resourceTemplates) { - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); - } - return this; - } - - /** - * Registers multiple prompts with their handlers using a Map. This method is - * useful when prompts are dynamically generated or loaded from a configuration - * source. - * - *

    - * Example usage:

    {@code
    -		 * Map prompts = new HashMap<>();
    -		 * prompts.put("analysis", new PromptRegistration(
    -		 *     new Prompt("analysis", "Code analysis template"),
    -		 *     request -> new GetPromptResult(generateAnalysisPrompt(request))
    -		 * ));
    -		 * .prompts(prompts)
    -		 * }
    - * @param prompts Map of prompt name to registration. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - */ - public SyncSpec prompts(Map prompts) { - this.prompts.putAll(prompts); - return this; - } - - /** - * Registers multiple prompts with their handlers using a List. This method is - * useful when prompts need to be added in bulk from a collection. - * @param prompts List of prompt registrations. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - * @see #prompts(McpServerFeatures.SyncPromptRegistration...) - */ - public SyncSpec prompts(List prompts) { - for (McpServerFeatures.SyncPromptRegistration prompt : prompts) { - this.prompts.put(prompt.prompt().name(), prompt); - } - return this; - } - - /** - * Registers multiple prompts with their handlers using varargs. This method - * provides a convenient way to register multiple prompts inline. - * - *

    - * Example usage:

    {@code
    -		 * .prompts(
    -		 *     new PromptRegistration(analysisPrompt, analysisHandler),
    -		 *     new PromptRegistration(summaryPrompt, summaryHandler),
    -		 *     new PromptRegistration(reviewPrompt, reviewHandler)
    -		 * )
    -		 * }
    - * @param prompts The prompt registrations to add. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - */ - public SyncSpec prompts(McpServerFeatures.SyncPromptRegistration... prompts) { - for (McpServerFeatures.SyncPromptRegistration prompt : prompts) { - this.prompts.put(prompt.prompt().name(), prompt); - } - return this; - } - - /** - * Registers a consumer that will be notified when the list of roots changes. This - * is useful for updating resource availability dynamically, such as when new - * files are added or removed. - * @param consumer The consumer to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumer is null - */ - public SyncSpec rootsChangeConsumer(Consumer> consumer) { - Assert.notNull(consumer, "Consumer must not be null"); - this.rootsChangeConsumers.add(consumer); - return this; - } - - /** - * Registers multiple consumers that will be notified when the list of roots - * changes. This method is useful when multiple consumers need to be registered at - * once. - * @param consumers The list of consumers to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumers is null - */ - public SyncSpec rootsChangeConsumers(List>> consumers) { - Assert.notNull(consumers, "Consumers list must not be null"); - this.rootsChangeConsumers.addAll(consumers); - return this; - } - - /** - * Registers multiple consumers that will be notified when the list of roots - * changes using varargs. This method provides a convenient way to register - * multiple consumers inline. - * @param consumers The consumers to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumers is null - */ - public SyncSpec rootsChangeConsumers(Consumer>... consumers) { - for (Consumer> consumer : consumers) { - this.rootsChangeConsumers.add(consumer); - } - return this; - } - - /** - * Builds a synchronous MCP server that provides blocking operations. - * @return A new instance of {@link McpSyncServer} configured with this builder's - * settings - */ - public McpSyncServer build() { - McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeConsumers); - return new McpSyncServer( - new McpAsyncServer(this.transport, McpServerFeatures.Async.fromSync(syncFeatures))); - } - - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java deleted file mode 100644 index c8f8399ab..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; - -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.Utils; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; - -/** - * MCP server features specification that a particular server can choose to support. - * - * @author Dariusz Jędrzejczyk - */ -public class McpServerFeatures { - - /** - * Asynchronous server features specification. - * - * @param serverInfo The server implementation details - * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations - * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations - * @param rootsChangeConsumers The list of consumers that will be notified when the - * roots list changes - */ - record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, Map resources, - List resourceTemplates, - Map prompts, - List, Mono>> rootsChangeConsumers) { - - /** - * Create an instance and validate the arguments. - * @param serverInfo The server implementation details - * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations - * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations - * @param rootsChangeConsumers The list of consumers that will be notified when - * the roots list changes - */ - Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, Map resources, - List resourceTemplates, - Map prompts, - List, Mono>> rootsChangeConsumers) { - - Assert.notNull(serverInfo, "Server info must not be null"); - - this.serverInfo = serverInfo; - this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities - : new McpSchema.ServerCapabilities(null, // experimental - new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable - // logging - // by - // default - !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, - !Utils.isEmpty(resources) - ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, - !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); - - this.tools = (tools != null) ? tools : List.of(); - this.resources = (resources != null) ? resources : Map.of(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); - this.prompts = (prompts != null) ? prompts : Map.of(); - this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); - } - - /** - * Convert a synchronous specification into an asynchronous one and provide - * blocking code offloading to prevent accidental blocking of the non-blocking - * transport. - * @param syncSpec a potentially blocking, synchronous specification. - * @return a specification which is protected from blocking calls specified by the - * user. - */ - static Async fromSync(Sync syncSpec) { - List tools = new ArrayList<>(); - for (var tool : syncSpec.tools()) { - tools.add(AsyncToolRegistration.fromSync(tool)); - } - - Map resources = new HashMap<>(); - syncSpec.resources().forEach((key, resource) -> { - resources.put(key, AsyncResourceRegistration.fromSync(resource)); - }); - - Map prompts = new HashMap<>(); - syncSpec.prompts().forEach((key, prompt) -> { - prompts.put(key, AsyncPromptRegistration.fromSync(prompt)); - }); - - List, Mono>> rootChangeConsumers = new ArrayList<>(); - - for (var rootChangeConsumer : syncSpec.rootsChangeConsumers()) { - rootChangeConsumers.add(list -> Mono.fromRunnable(() -> rootChangeConsumer.accept(list)) - .subscribeOn(Schedulers.boundedElastic())); - } - - return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, - syncSpec.resourceTemplates(), prompts, rootChangeConsumers); - } - } - - /** - * Synchronous server features specification. - * - * @param serverInfo The server implementation details - * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations - * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations - * @param rootsChangeConsumers The list of consumers that will be notified when the - * roots list changes - */ - record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, - Map resources, - List resourceTemplates, - Map prompts, - List>> rootsChangeConsumers) { - - /** - * Create an instance and validate the arguments. - * @param serverInfo The server implementation details - * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations - * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations - * @param rootsChangeConsumers The list of consumers that will be notified when - * the roots list changes - */ - Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, - Map resources, - List resourceTemplates, - Map prompts, - List>> rootsChangeConsumers) { - - Assert.notNull(serverInfo, "Server info must not be null"); - - this.serverInfo = serverInfo; - this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities - : new McpSchema.ServerCapabilities(null, // experimental - new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable - // logging - // by - // default - !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, - !Utils.isEmpty(resources) - ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, - !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); - - this.tools = (tools != null) ? tools : new ArrayList<>(); - this.resources = (resources != null) ? resources : new HashMap<>(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); - this.prompts = (prompts != null) ? prompts : new HashMap<>(); - this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : new ArrayList<>(); - } - - } - - /** - * Registration of a tool with its asynchronous handler function. Tools are the - * primary way for MCP servers to expose functionality to AI models. Each tool - * represents a specific capability, such as: - *
      - *
    • Performing calculations - *
    • Accessing external APIs - *
    • Querying databases - *
    • Manipulating files - *
    • Executing system commands - *
    - * - *

    - * Example tool registration:

    {@code
    -	 * new McpServerFeatures.AsyncToolRegistration(
    -	 *     new Tool(
    -	 *         "calculator",
    -	 *         "Performs mathematical calculations",
    -	 *         new JsonSchemaObject()
    -	 *             .required("expression")
    -	 *             .property("expression", JsonSchemaType.STRING)
    -	 *     ),
    -	 *     args -> {
    -	 *         String expr = (String) args.get("expression");
    -	 *         return Mono.just(new CallToolResult("Result: " + evaluate(expr)));
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param tool The tool definition including name, description, and parameter schema - * @param call The function that implements the tool's logic, receiving arguments and - * returning results - */ - public record AsyncToolRegistration(McpSchema.Tool tool, - Function, Mono> call) { - - static AsyncToolRegistration fromSync(SyncToolRegistration tool) { - // FIXME: This is temporary, proper validation should be implemented - if (tool == null) { - return null; - } - return new AsyncToolRegistration(tool.tool(), - map -> Mono.fromCallable(() -> tool.call().apply(map)).subscribeOn(Schedulers.boundedElastic())); - } - } - - /** - * Registration of a resource with its asynchronous handler function. Resources - * provide context to AI models by exposing data such as: - *
      - *
    • File contents - *
    • Database records - *
    • API responses - *
    • System information - *
    • Application state - *
    - * - *

    - * Example resource registration:

    {@code
    -	 * new McpServerFeatures.AsyncResourceRegistration(
    -	 *     new Resource("docs", "Documentation files", "text/markdown"),
    -	 *     request -> {
    -	 *         String content = readFile(request.getPath());
    -	 *         return Mono.just(new ReadResourceResult(content));
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param resource The resource definition including name, description, and MIME type - * @param readHandler The function that handles resource read requests - */ - public record AsyncResourceRegistration(McpSchema.Resource resource, - Function> readHandler) { - - static AsyncResourceRegistration fromSync(SyncResourceRegistration resource) { - // FIXME: This is temporary, proper validation should be implemented - if (resource == null) { - return null; - } - return new AsyncResourceRegistration(resource.resource(), - req -> Mono.fromCallable(() -> resource.readHandler().apply(req)) - .subscribeOn(Schedulers.boundedElastic())); - } - } - - /** - * Registration of a prompt template with its asynchronous handler function. Prompts - * provide structured templates for AI model interactions, supporting: - *
      - *
    • Consistent message formatting - *
    • Parameter substitution - *
    • Context injection - *
    • Response formatting - *
    • Instruction templating - *
    - * - *

    - * Example prompt registration:

    {@code
    -	 * new McpServerFeatures.AsyncPromptRegistration(
    -	 *     new Prompt("analyze", "Code analysis template"),
    -	 *     request -> {
    -	 *         String code = request.getArguments().get("code");
    -	 *         return Mono.just(new GetPromptResult(
    -	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    -	 *         ));
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param prompt The prompt definition including name and description - * @param promptHandler The function that processes prompt requests and returns - * formatted templates - */ - public record AsyncPromptRegistration(McpSchema.Prompt prompt, - Function> promptHandler) { - - static AsyncPromptRegistration fromSync(SyncPromptRegistration prompt) { - // FIXME: This is temporary, proper validation should be implemented - if (prompt == null) { - return null; - } - return new AsyncPromptRegistration(prompt.prompt(), - req -> Mono.fromCallable(() -> prompt.promptHandler().apply(req)) - .subscribeOn(Schedulers.boundedElastic())); - } - } - - /** - * Registration of a tool with its synchronous handler function. Tools are the primary - * way for MCP servers to expose functionality to AI models. Each tool represents a - * specific capability, such as: - *
      - *
    • Performing calculations - *
    • Accessing external APIs - *
    • Querying databases - *
    • Manipulating files - *
    • Executing system commands - *
    - * - *

    - * Example tool registration:

    {@code
    -	 * new McpServerFeatures.SyncToolRegistration(
    -	 *     new Tool(
    -	 *         "calculator",
    -	 *         "Performs mathematical calculations",
    -	 *         new JsonSchemaObject()
    -	 *             .required("expression")
    -	 *             .property("expression", JsonSchemaType.STRING)
    -	 *     ),
    -	 *     args -> {
    -	 *         String expr = (String) args.get("expression");
    -	 *         return new CallToolResult("Result: " + evaluate(expr));
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param tool The tool definition including name, description, and parameter schema - * @param call The function that implements the tool's logic, receiving arguments and - * returning results - */ - public record SyncToolRegistration(McpSchema.Tool tool, - Function, McpSchema.CallToolResult> call) { - } - - /** - * Registration of a resource with its synchronous handler function. Resources provide - * context to AI models by exposing data such as: - *
      - *
    • File contents - *
    • Database records - *
    • API responses - *
    • System information - *
    • Application state - *
    - * - *

    - * Example resource registration:

    {@code
    -	 * new McpServerFeatures.SyncResourceRegistration(
    -	 *     new Resource("docs", "Documentation files", "text/markdown"),
    -	 *     request -> {
    -	 *         String content = readFile(request.getPath());
    -	 *         return new ReadResourceResult(content);
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param resource The resource definition including name, description, and MIME type - * @param readHandler The function that handles resource read requests - */ - public record SyncResourceRegistration(McpSchema.Resource resource, - Function readHandler) { - } - - /** - * Registration of a prompt template with its synchronous handler function. Prompts - * provide structured templates for AI model interactions, supporting: - *
      - *
    • Consistent message formatting - *
    • Parameter substitution - *
    • Context injection - *
    • Response formatting - *
    • Instruction templating - *
    - * - *

    - * Example prompt registration:

    {@code
    -	 * new McpServerFeatures.SyncPromptRegistration(
    -	 *     new Prompt("analyze", "Code analysis template"),
    -	 *     request -> {
    -	 *         String code = request.getArguments().get("code");
    -	 *         return new GetPromptResult(
    -	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    -	 *         );
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param prompt The prompt definition including name and description - * @param promptHandler The function that processes prompt requests and returns - * formatted templates - */ - public record SyncPromptRegistration(McpSchema.Prompt prompt, - Function promptHandler) { - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java deleted file mode 100644 index 98b8ea582..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java +++ /dev/null @@ -1,416 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.server.transport; - -import java.io.BufferedReader; -import java.io.IOException; -import java.io.PrintWriter; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import jakarta.servlet.AsyncContext; -import jakarta.servlet.ServletException; -import jakarta.servlet.annotation.WebServlet; -import jakarta.servlet.http.HttpServlet; -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; - -/** - * A Servlet-based implementation of the MCP HTTP with Server-Sent Events (SSE) transport - * specification. This implementation provides similar functionality to - * WebFluxSseServerTransport but uses the traditional Servlet API instead of WebFlux. - * - *

    - * The transport handles two types of endpoints: - *

      - *
    • SSE endpoint (/sse) - Establishes a long-lived connection for server-to-client - * events
    • - *
    • Message endpoint (configurable) - Handles client-to-server message requests
    • - *
    - * - *

    - * Features: - *

      - *
    • Asynchronous message handling using Servlet 6.0 async support
    • - *
    • Session management for multiple client connections
    • - *
    • Graceful shutdown support
    • - *
    • Error handling and response formatting
    • - *
    - * - * @author Christian Tzolov - * @author Alexandros Pappas - * @see ServerMcpTransport - * @see HttpServlet - */ - -@WebServlet(asyncSupported = true) -public class HttpServletSseServerTransport extends HttpServlet implements ServerMcpTransport { - - /** Logger for this class */ - private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransport.class); - - public static final String UTF_8 = "UTF-8"; - - public static final String APPLICATION_JSON = "application/json"; - - public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; - - /** Default endpoint path for SSE connections */ - public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - - /** Event type for regular messages */ - public static final String MESSAGE_EVENT_TYPE = "message"; - - /** Event type for endpoint information */ - public static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** JSON object mapper for serialization/deserialization */ - private final ObjectMapper objectMapper; - - /** The endpoint path for handling client messages */ - private final String messageEndpoint; - - /** The endpoint path for handling SSE connections */ - private final String sseEndpoint; - - /** Map of active client sessions, keyed by session ID */ - private final Map sessions = new ConcurrentHashMap<>(); - - /** Flag indicating if the transport is in the process of shutting down */ - private final AtomicBoolean isClosing = new AtomicBoolean(false); - - /** Handler for processing incoming messages */ - private Function, Mono> connectHandler; - - /** - * Creates a new HttpServletSseServerTransport instance with a custom SSE endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param messageEndpoint The endpoint path where clients will send their messages - * @param sseEndpoint The endpoint path where clients will establish SSE connections - */ - public HttpServletSseServerTransport(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this.objectMapper = objectMapper; - this.messageEndpoint = messageEndpoint; - this.sseEndpoint = sseEndpoint; - } - - /** - * Creates a new HttpServletSseServerTransport instance with the default SSE endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param messageEndpoint The endpoint path where clients will send their messages - */ - public HttpServletSseServerTransport(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - - /** - * Handles GET requests to establish SSE connections. - *

    - * This method sets up a new SSE connection when a client connects to the SSE - * endpoint. It configures the response headers for SSE, creates a new session, and - * sends the initial endpoint information to the client. - * @param request The HTTP servlet request - * @param response The HTTP servlet response - * @throws ServletException If a servlet-specific error occurs - * @throws IOException If an I/O error occurs - */ - @Override - protected void doGet(HttpServletRequest request, HttpServletResponse response) - throws ServletException, IOException { - - String pathInfo = request.getPathInfo(); - if (!sseEndpoint.equals(pathInfo)) { - response.sendError(HttpServletResponse.SC_NOT_FOUND); - return; - } - - if (isClosing.get()) { - response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); - return; - } - - response.setContentType("text/event-stream"); - response.setCharacterEncoding(UTF_8); - response.setHeader("Cache-Control", "no-cache"); - response.setHeader("Connection", "keep-alive"); - response.setHeader("Access-Control-Allow-Origin", "*"); - - String sessionId = UUID.randomUUID().toString(); - AsyncContext asyncContext = request.startAsync(); - asyncContext.setTimeout(0); - - PrintWriter writer = response.getWriter(); - ClientSession session = new ClientSession(sessionId, asyncContext, writer); - this.sessions.put(sessionId, session); - - // Send initial endpoint event - this.sendEvent(writer, ENDPOINT_EVENT_TYPE, messageEndpoint); - } - - /** - * Handles POST requests for client messages. - *

    - * This method processes incoming messages from clients, routes them through the - * connect handler if configured, and sends back the appropriate response. It handles - * error cases and formats error responses according to the MCP specification. - * @param request The HTTP servlet request - * @param response The HTTP servlet response - * @throws ServletException If a servlet-specific error occurs - * @throws IOException If an I/O error occurs - */ - @Override - protected void doPost(HttpServletRequest request, HttpServletResponse response) - throws ServletException, IOException { - - if (isClosing.get()) { - response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); - return; - } - - String pathInfo = request.getPathInfo(); - if (!messageEndpoint.equals(pathInfo)) { - response.sendError(HttpServletResponse.SC_NOT_FOUND); - return; - } - - try { - BufferedReader reader = request.getReader(); - StringBuilder body = new StringBuilder(); - String line; - while ((line = reader.readLine()) != null) { - body.append(line); - } - - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); - - if (connectHandler != null) { - connectHandler.apply(Mono.just(message)).subscribe(responseMessage -> { - try { - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - String jsonResponse = objectMapper.writeValueAsString(responseMessage); - PrintWriter writer = response.getWriter(); - writer.write(jsonResponse); - writer.flush(); - } - catch (Exception e) { - logger.error("Error sending response: {}", e.getMessage()); - try { - response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - "Error processing response: " + e.getMessage()); - } - catch (IOException ex) { - logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); - } - } - }, error -> { - try { - logger.error("Error processing message: {}", error.getMessage()); - McpError mcpError = new McpError(error.getMessage()); - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); - String jsonError = objectMapper.writeValueAsString(mcpError); - PrintWriter writer = response.getWriter(); - writer.write(jsonError); - writer.flush(); - } - catch (IOException e) { - logger.error(FAILED_TO_SEND_ERROR_RESPONSE, e.getMessage()); - try { - response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - "Error sending error response: " + e.getMessage()); - } - catch (IOException ex) { - logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); - } - } - }); - } - else { - response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "No message handler configured"); - } - } - catch (Exception e) { - logger.error("Invalid message format: {}", e.getMessage()); - try { - McpError mcpError = new McpError("Invalid message format: " + e.getMessage()); - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - response.setStatus(HttpServletResponse.SC_BAD_REQUEST); - String jsonError = objectMapper.writeValueAsString(mcpError); - PrintWriter writer = response.getWriter(); - writer.write(jsonError); - writer.flush(); - } - catch (IOException ex) { - logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); - response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid message format"); - } - } - } - - /** - * Sets up the message handler for processing client requests. - * @param handler The function to process incoming messages and produce responses - * @return A Mono that completes when the handler is set up - */ - @Override - public Mono connect(Function, Mono> handler) { - this.connectHandler = handler; - return Mono.empty(); - } - - /** - * Broadcasts a message to all connected clients. - *

    - * This method serializes the message and sends it to all active client sessions. If a - * client is disconnected, its session is removed. - * @param message The message to broadcast - * @return A Mono that completes when the message has been sent to all clients - */ - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - if (sessions.isEmpty()) { - logger.debug("No active sessions to broadcast message to"); - return Mono.empty(); - } - - return Mono.create(sink -> { - try { - String jsonText = objectMapper.writeValueAsString(message); - - sessions.values().forEach(session -> { - try { - this.sendEvent(session.writer, MESSAGE_EVENT_TYPE, jsonText); - } - catch (IOException e) { - logger.error("Failed to send message to session {}: {}", session.id, e.getMessage()); - removeSession(session); - } - }); - - sink.success(); - } - catch (Exception e) { - logger.error("Failed to process message: {}", e.getMessage()); - sink.error(new McpError("Failed to process message: " + e.getMessage())); - } - }); - } - - /** - * Closes the transport. - *

    - * This implementation delegates to the super class's close method. - */ - @Override - public void close() { - ServerMcpTransport.super.close(); - } - - /** - * Unmarshals data from one type to another using the object mapper. - * @param The target type - * @param data The source data - * @param typeRef The type reference for the target type - * @return The unmarshaled data - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); - } - - /** - * Initiates a graceful shutdown of the transport. - *

    - * This method marks the transport as closing and closes all active client sessions. - * New connection attempts will be rejected during shutdown. - * @return A Mono that completes when all sessions have been closed - */ - @Override - public Mono closeGracefully() { - isClosing.set(true); - logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - - return Mono.create(sink -> { - sessions.values().forEach(this::removeSession); - sink.success(); - }); - } - - /** - * Sends an SSE event to a client. - * @param writer The writer to send the event through - * @param eventType The type of event (message or endpoint) - * @param data The event data - * @throws IOException If an error occurs while writing the event - */ - private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException { - writer.write("event: " + eventType + "\n"); - writer.write("data: " + data + "\n\n"); - writer.flush(); - - if (writer.checkError()) { - throw new IOException("Client disconnected"); - } - } - - /** - * Removes a client session and completes its async context. - * @param session The session to remove - */ - private void removeSession(ClientSession session) { - sessions.remove(session.id); - session.asyncContext.complete(); - } - - /** - * Represents a client connection session. - *

    - * This class holds the necessary information about a client's SSE connection, - * including its ID, async context, and output writer. - */ - private static class ClientSession { - - private final String id; - - private final AsyncContext asyncContext; - - private final PrintWriter writer; - - ClientSession(String id, AsyncContext asyncContext, PrintWriter writer) { - this.id = id; - this.asyncContext = asyncContext; - this.writer = writer; - } - - } - - /** - * Cleans up resources when the servlet is being destroyed. - *

    - * This method ensures a graceful shutdown by closing all client connections before - * calling the parent's destroy method. - */ - @Override - public void destroy() { - closeGracefully().block(); - super.destroy(); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java deleted file mode 100644 index e375cd108..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.io.OutputStream; -import java.nio.charset.StandardCharsets; -import java.util.concurrent.Executors; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.core.scheduler.Scheduler; -import reactor.core.scheduler.Schedulers; - -/** - * Implementation of the MCP Stdio transport for servers that communicates using standard - * input/output streams. Messages are exchanged as newline-delimited JSON-RPC messages - * over stdin/stdout, with errors and debug information sent to stderr. - * - * @author Christian Tzolov - */ -public class StdioServerTransport implements ServerMcpTransport { - - private static final Logger logger = LoggerFactory.getLogger(StdioServerTransport.class); - - private final Sinks.Many inboundSink; - - private final Sinks.Many outboundSink; - - private ObjectMapper objectMapper; - - /** Scheduler for handling inbound messages */ - private Scheduler inboundScheduler; - - /** Scheduler for handling outbound messages */ - private Scheduler outboundScheduler; - - private volatile boolean isClosing = false; - - private final InputStream inputStream; - - private final OutputStream outputStream; - - private final Sinks.One inboundReady = Sinks.one(); - - private final Sinks.One outboundReady = Sinks.one(); - - /** - * Creates a new StdioServerTransport with a default ObjectMapper and System streams. - */ - public StdioServerTransport() { - this(new ObjectMapper()); - } - - /** - * Creates a new StdioServerTransport with the specified ObjectMapper and System - * streams. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - */ - public StdioServerTransport(ObjectMapper objectMapper) { - - Assert.notNull(objectMapper, "The ObjectMapper can not be null"); - - this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); - this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); - - this.objectMapper = objectMapper; - this.inputStream = System.in; - this.outputStream = System.out; - - // Use bounded schedulers for better resource management - this.inboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "inbound"); - this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); - } - - @Override - public Mono connect(Function, Mono> handler) { - return Mono.fromRunnable(() -> { - handleIncomingMessages(handler); - - // Start threads - startInboundProcessing(); - startOutboundProcessing(); - }).subscribeOn(Schedulers.boundedElastic()); - } - - private void handleIncomingMessages(Function, Mono> inboundMessageHandler) { - this.inboundSink.asFlux() - .flatMap(message -> Mono.just(message) - .transform(inboundMessageHandler) - .contextWrite(ctx -> ctx.put("observation", "myObservation"))) - .doOnTerminate(() -> { - // The outbound processing will dispose its scheduler upon completion - this.outboundSink.tryEmitComplete(); - this.inboundScheduler.dispose(); - }) - .subscribe(); - } - - @Override - public Mono sendMessage(JSONRPCMessage message) { - return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { - if (this.outboundSink.tryEmitNext(message).isSuccess()) { - return Mono.empty(); - } - else { - return Mono.error(new RuntimeException("Failed to enqueue message")); - } - })); - } - - /** - * Starts the inbound processing thread that reads JSON-RPC messages from stdin. - * Messages are deserialized and emitted to the inbound sink. - */ - private void startInboundProcessing() { - this.inboundScheduler.schedule(() -> { - inboundReady.tryEmitValue(null); - BufferedReader reader = null; - try { - reader = new BufferedReader(new InputStreamReader(inputStream)); - while (!isClosing) { - try { - String line = reader.readLine(); - if (line == null || isClosing) { - break; - } - - logger.debug("Received JSON message: {}", line); - - try { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, line); - if (!this.inboundSink.tryEmitNext(message).isSuccess()) { - logIfNotClosing("Failed to enqueue message"); - break; - } - } - catch (Exception e) { - logIfNotClosing("Error processing inbound message", e); - break; - } - } - catch (IOException e) { - logIfNotClosing("Error reading from stdin", e); - break; - } - } - } - catch (Exception e) { - logIfNotClosing("Error in inbound processing", e); - } - finally { - isClosing = true; - inboundSink.tryEmitComplete(); - } - }); - } - - /** - * Starts the outbound processing thread that writes JSON-RPC messages to stdout. - * Messages are serialized to JSON and written with a newline delimiter. - */ - private void startOutboundProcessing() { - Function, Flux> outboundConsumer = messages -> messages // @formatter:off - .doOnSubscribe(subscription -> outboundReady.tryEmitValue(null)) - .publishOn(outboundScheduler) - .handle((message, sink) -> { - if (message != null && !isClosing) { - try { - String jsonMessage = objectMapper.writeValueAsString(message); - // Escape any embedded newlines in the JSON message as per spec - jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n"); - - synchronized (outputStream) { - outputStream.write(jsonMessage.getBytes(StandardCharsets.UTF_8)); - outputStream.write("\n".getBytes(StandardCharsets.UTF_8)); - outputStream.flush(); - } - sink.next(message); - } - catch (IOException e) { - if (!isClosing) { - logger.error("Error writing message", e); - sink.error(new RuntimeException(e)); - } - else { - logger.debug("Stream closed during shutdown", e); - } - } - } - else if (isClosing) { - sink.complete(); - } - }) - .doOnComplete(() -> { - isClosing = true; - outboundScheduler.dispose(); - }) - .doOnError(e -> { - if (!isClosing) { - logger.error("Error in outbound processing", e); - isClosing = true; - outboundScheduler.dispose(); - } - }) - .map(msg -> (JSONRPCMessage) msg); - - outboundConsumer.apply(outboundSink.asFlux()).subscribe(); - } // @formatter:on - - @Override - public Mono closeGracefully() { - return Mono.defer(() -> { - isClosing = true; - logger.debug("Initiating graceful shutdown"); - // Completing the inbound causes the outbound to be completed as well, so - // we only close the inbound. - inboundSink.tryEmitComplete(); - logger.debug("Graceful shutdown complete"); - return Mono.empty(); - }).subscribeOn(Schedulers.boundedElastic()); - } - - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - - private void logIfNotClosing(String message, Exception e) { - if (!this.isClosing) { - logger.error(message, e); - } - } - - private void logIfNotClosing(String message) { - if (!this.isClosing) { - logger.error(message); - } - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java deleted file mode 100644 index 8a9b4ce02..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java +++ /dev/null @@ -1,13 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.spec; - -/** - * Marker interface for the client-side MCP transport. - * - * @author Christian Tzolov - */ -public interface ClientMcpTransport extends McpTransport { - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java deleted file mode 100644 index 13e43240b..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java +++ /dev/null @@ -1,25 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.spec; - -import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse.JSONRPCError; - -public class McpError extends RuntimeException { - - private JSONRPCError jsonRpcError; - - public McpError(JSONRPCError jsonRpcError) { - super(jsonRpcError.message()); - this.jsonRpcError = jsonRpcError; - } - - public McpError(Object error) { - super(error.toString()); - } - - public JSONRPCError getJsonRpcError() { - return jsonRpcError; - } - -} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java deleted file mode 100644 index 2f5511969..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ /dev/null @@ -1,1071 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.spec; - -import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonSubTypes; -import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.annotation.JsonTypeInfo.As; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Based on the JSON-RPC 2.0 - * specification and the Model - * Context Protocol Schema. - * - * @author Christian Tzolov - */ -public final class McpSchema { - - private static final Logger logger = LoggerFactory.getLogger(McpSchema.class); - - private McpSchema() { - } - - public static final String LATEST_PROTOCOL_VERSION = "2024-11-05"; - - public static final String JSONRPC_VERSION = "2.0"; - - // --------------------------- - // Method Names - // --------------------------- - - // Lifecycle Methods - public static final String METHOD_INITIALIZE = "initialize"; - - public static final String METHOD_NOTIFICATION_INITIALIZED = "notifications/initialized"; - - public static final String METHOD_PING = "ping"; - - // Tool Methods - public static final String METHOD_TOOLS_LIST = "tools/list"; - - public static final String METHOD_TOOLS_CALL = "tools/call"; - - public static final String METHOD_NOTIFICATION_TOOLS_LIST_CHANGED = "notifications/tools/list_changed"; - - // Resources Methods - public static final String METHOD_RESOURCES_LIST = "resources/list"; - - public static final String METHOD_RESOURCES_READ = "resources/read"; - - public static final String METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED = "notifications/resources/list_changed"; - - public static final String METHOD_RESOURCES_TEMPLATES_LIST = "resources/templates/list"; - - public static final String METHOD_RESOURCES_SUBSCRIBE = "resources/subscribe"; - - public static final String METHOD_RESOURCES_UNSUBSCRIBE = "resources/unsubscribe"; - - // Prompt Methods - public static final String METHOD_PROMPT_LIST = "prompts/list"; - - public static final String METHOD_PROMPT_GET = "prompts/get"; - - public static final String METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED = "notifications/prompts/list_changed"; - - // Logging Methods - public static final String METHOD_LOGGING_SET_LEVEL = "logging/setLevel"; - - public static final String METHOD_NOTIFICATION_MESSAGE = "notifications/message"; - - // Roots Methods - public static final String METHOD_ROOTS_LIST = "roots/list"; - - public static final String METHOD_NOTIFICATION_ROOTS_LIST_CHANGED = "notifications/roots/list_changed"; - - // Sampling Methods - public static final String METHOD_SAMPLING_CREATE_MESSAGE = "sampling/createMessage"; - - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - - // --------------------------- - // JSON-RPC Error Codes - // --------------------------- - /** - * Standard error codes used in MCP JSON-RPC responses. - */ - public static final class ErrorCodes { - - /** - * Invalid JSON was received by the server. - */ - public static final int PARSE_ERROR = -32700; - - /** - * The JSON sent is not a valid Request object. - */ - public static final int INVALID_REQUEST = -32600; - - /** - * The method does not exist / is not available. - */ - public static final int METHOD_NOT_FOUND = -32601; - - /** - * Invalid method parameter(s). - */ - public static final int INVALID_PARAMS = -32602; - - /** - * Internal JSON-RPC error. - */ - public static final int INTERNAL_ERROR = -32603; - - } - - public sealed interface Request - permits InitializeRequest, CallToolRequest, CreateMessageRequest, CompleteRequest, GetPromptRequest { - - } - - private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() { - }; - - /** - * Deserializes a JSON string into a JSONRPCMessage object. - * @param objectMapper The ObjectMapper instance to use for deserialization - * @param jsonText The JSON string to deserialize - * @return A JSONRPCMessage instance using either the {@link JSONRPCRequest}, - * {@link JSONRPCNotification}, or {@link JSONRPCResponse} classes. - * @throws IOException If there's an error during deserialization - * @throws IllegalArgumentException If the JSON structure doesn't match any known - * message type - */ - public static JSONRPCMessage deserializeJsonRpcMessage(ObjectMapper objectMapper, String jsonText) - throws IOException { - - logger.debug("Received JSON message: {}", jsonText); - - var map = objectMapper.readValue(jsonText, MAP_TYPE_REF); - - // Determine message type based on specific JSON structure - if (map.containsKey("method") && map.containsKey("id")) { - return objectMapper.convertValue(map, JSONRPCRequest.class); - } - else if (map.containsKey("method") && !map.containsKey("id")) { - return objectMapper.convertValue(map, JSONRPCNotification.class); - } - else if (map.containsKey("result") || map.containsKey("error")) { - return objectMapper.convertValue(map, JSONRPCResponse.class); - } - - throw new IllegalArgumentException("Cannot deserialize JSONRPCMessage: " + jsonText); - } - - // --------------------------- - // JSON-RPC Message Types - // --------------------------- - public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotification, JSONRPCResponse { - - String jsonrpc(); - - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record JSONRPCRequest( // @formatter:off - @JsonProperty("jsonrpc") String jsonrpc, - @JsonProperty("method") String method, - @JsonProperty("id") Object id, - @JsonProperty("params") Object params) implements JSONRPCMessage { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record JSONRPCNotification( // @formatter:off - @JsonProperty("jsonrpc") String jsonrpc, - @JsonProperty("method") String method, - @JsonProperty("params") Map params) implements JSONRPCMessage { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record JSONRPCResponse( // @formatter:off - @JsonProperty("jsonrpc") String jsonrpc, - @JsonProperty("id") Object id, - @JsonProperty("result") Object result, - @JsonProperty("error") JSONRPCError error) implements JSONRPCMessage { - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record JSONRPCError( - @JsonProperty("code") int code, - @JsonProperty("message") String message, - @JsonProperty("data") Object data) { - } - }// @formatter:on - - // --------------------------- - // Initialization - // --------------------------- - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record InitializeRequest( // @formatter:off - @JsonProperty("protocolVersion") String protocolVersion, - @JsonProperty("capabilities") ClientCapabilities capabilities, - @JsonProperty("clientInfo") Implementation clientInfo) implements Request { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record InitializeResult( // @formatter:off - @JsonProperty("protocolVersion") String protocolVersion, - @JsonProperty("capabilities") ServerCapabilities capabilities, - @JsonProperty("serverInfo") Implementation serverInfo, - @JsonProperty("instructions") String instructions) { - } // @formatter:on - - /** - * Clients can implement additional features to enrich connected MCP servers with - * additional capabilities. These capabilities can be used to extend the functionality - * of the server, or to provide additional information to the server about the - * client's capabilities. - * - * @param experimental WIP - * @param roots define the boundaries of where servers can operate within the - * filesystem, allowing them to understand which directories and files they have - * access to. - * @param sampling Provides a standardized way for servers to request LLM sampling - * (“completions” or “generations”) from language models via clients. - * - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ClientCapabilities( // @formatter:off - @JsonProperty("experimental") Map experimental, - @JsonProperty("roots") RootCapabilities roots, - @JsonProperty("sampling") Sampling sampling) { - - /** - * Roots define the boundaries of where servers can operate within the filesystem, - * allowing them to understand which directories and files they have access to. - * Servers can request the list of roots from supporting clients and - * receive notifications when that list changes. - * - * @param listChanged Whether the client would send notification about roots - * has changed since the last time the server checked. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record RootCapabilities( - @JsonProperty("listChanged") Boolean listChanged) { - } - - /** - * Provides a standardized way for servers to request LLM - * sampling ("completions" or "generations") from language - * models via clients. This flow allows clients to maintain - * control over model access, selection, and permissions - * while enabling servers to leverage AI capabilities—with - * no server API keys necessary. Servers can request text or - * image-based interactions and optionally include context - * from MCP servers in their prompts. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record Sampling() { - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private Map experimental; - private RootCapabilities roots; - private Sampling sampling; - - public Builder experimental(Map experimental) { - this.experimental = experimental; - return this; - } - - public Builder roots(Boolean listChanged) { - this.roots = new RootCapabilities(listChanged); - return this; - } - - public Builder sampling() { - this.sampling = new Sampling(); - return this; - } - - public ClientCapabilities build() { - return new ClientCapabilities(experimental, roots, sampling); - } - } - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ServerCapabilities( // @formatter:off - @JsonProperty("experimental") Map experimental, - @JsonProperty("logging") LoggingCapabilities logging, - @JsonProperty("prompts") PromptCapabilities prompts, - @JsonProperty("resources") ResourceCapabilities resources, - @JsonProperty("tools") ToolCapabilities tools) { - - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record LoggingCapabilities() { - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record PromptCapabilities( - @JsonProperty("listChanged") Boolean listChanged) { - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record ResourceCapabilities( - @JsonProperty("subscribe") Boolean subscribe, - @JsonProperty("listChanged") Boolean listChanged) { - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record ToolCapabilities( - @JsonProperty("listChanged") Boolean listChanged) { - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - private Map experimental; - private LoggingCapabilities logging = new LoggingCapabilities(); - private PromptCapabilities prompts; - private ResourceCapabilities resources; - private ToolCapabilities tools; - - public Builder experimental(Map experimental) { - this.experimental = experimental; - return this; - } - - public Builder logging() { - this.logging = new LoggingCapabilities(); - return this; - } - - public Builder prompts(Boolean listChanged) { - this.prompts = new PromptCapabilities(listChanged); - return this; - } - - public Builder resources(Boolean subscribe, Boolean listChanged) { - this.resources = new ResourceCapabilities(subscribe, listChanged); - return this; - } - - public Builder tools(Boolean listChanged) { - this.tools = new ToolCapabilities(listChanged); - return this; - } - - public ServerCapabilities build() { - return new ServerCapabilities(experimental, logging, prompts, resources, tools); - } - } - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record Implementation(// @formatter:off - @JsonProperty("name") String name, - @JsonProperty("version") String version) { - } // @formatter:on - - // Existing Enums and Base Types (from previous implementation) - public enum Role {// @formatter:off - - @JsonProperty("user") USER, - @JsonProperty("assistant") ASSISTANT - }// @formatter:on - - // --------------------------- - // Resource Interfaces - // --------------------------- - /** - * Base for objects that include optional annotations for the client. The client can - * use annotations to inform how objects are used or displayed - */ - public interface Annotated { - - Annotations annotations(); - - } - - /** - * Optional annotations for the client. The client can use annotations to inform how - * objects are used or displayed. - * - * @param audience Describes who the intended customer of this object or data is. It - * can include multiple entries to indicate content useful for multiple audiences - * (e.g., `["user", "assistant"]`). - * @param priority Describes how important this data is for operating the server. A - * value of 1 means "most important," and indicates that the data is effectively - * required, while 0 means "least important," and indicates that the data is entirely - * optional. It is a number between 0 and 1. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record Annotations( // @formatter:off - @JsonProperty("audience") List audience, - @JsonProperty("priority") Double priority) { - } // @formatter:on - - /** - * A known resource that the server is capable of reading. - * - * @param uri the URI of the resource. - * @param name A human-readable name for this resource. This can be used by clients to - * populate UI elements. - * @param description A description of what this resource represents. This can be used - * by clients to improve the LLM's understanding of available resources. It can be - * thought of like a "hint" to the model. - * @param mimeType The MIME type of this resource, if known. - * @param annotations Optional annotations for the client. The client can use - * annotations to inform how objects are used or displayed. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record Resource( // @formatter:off - @JsonProperty("uri") String uri, - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("mimeType") String mimeType, - @JsonProperty("annotations") Annotations annotations) implements Annotated { - } // @formatter:on - - /** - * Resource templates allow servers to expose parameterized resources using URI - * templates. - * - * @param uriTemplate A URI template that can be used to generate URIs for this - * resource. - * @param name A human-readable name for this resource. This can be used by clients to - * populate UI elements. - * @param description A description of what this resource represents. This can be used - * by clients to improve the LLM's understanding of available resources. It can be - * thought of like a "hint" to the model. - * @param mimeType The MIME type of this resource, if known. - * @param annotations Optional annotations for the client. The client can use - * annotations to inform how objects are used or displayed. - * @see RFC 6570 - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ResourceTemplate( // @formatter:off - @JsonProperty("uriTemplate") String uriTemplate, - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("mimeType") String mimeType, - @JsonProperty("annotations") Annotations annotations) implements Annotated { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ListResourcesResult( // @formatter:off - @JsonProperty("resources") List resources, - @JsonProperty("nextCursor") String nextCursor) { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ListResourceTemplatesResult( // @formatter:off - @JsonProperty("resourceTemplates") List resourceTemplates, - @JsonProperty("nextCursor") String nextCursor) { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ReadResourceRequest( // @formatter:off - @JsonProperty("uri") String uri){ - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ReadResourceResult( // @formatter:off - @JsonProperty("contents") List contents){ - } // @formatter:on - - /** - * Sent from the client to request resources/updated notifications from the server - * whenever a particular resource changes. - * - * @param uri the URI of the resource to subscribe to. The URI can use any protocol; - * it is up to the server how to interpret it. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record SubscribeRequest( // @formatter:off - @JsonProperty("uri") String uri){ - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record UnsubscribeRequest( // @formatter:off - @JsonProperty("uri") String uri){ - } // @formatter:on - - /** - * The contents of a specific resource or sub-resource. - */ - @JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION, include = As.PROPERTY) - @JsonSubTypes({ @JsonSubTypes.Type(value = TextResourceContents.class, name = "text"), - @JsonSubTypes.Type(value = BlobResourceContents.class, name = "blob") }) - public sealed interface ResourceContents permits TextResourceContents, BlobResourceContents { - - /** - * The URI of this resource. - * @return the URI of this resource. - */ - String uri(); - - /** - * The MIME type of this resource. - * @return the MIME type of this resource. - */ - String mimeType(); - - } - - /** - * Text contents of a resource. - * - * @param uri the URI of this resource. - * @param mimeType the MIME type of this resource. - * @param text the text of the resource. This must only be set if the resource can - * actually be represented as text (not binary data). - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record TextResourceContents( // @formatter:off - @JsonProperty("uri") String uri, - @JsonProperty("mimeType") String mimeType, - @JsonProperty("text") String text) implements ResourceContents { - } // @formatter:on - - /** - * Binary contents of a resource. - * - * @param uri the URI of this resource. - * @param mimeType the MIME type of this resource. - * @param blob a base64-encoded string representing the binary data of the resource. - * This must only be set if the resource can actually be represented as binary data - * (not text). - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record BlobResourceContents( // @formatter:off - @JsonProperty("uri") String uri, - @JsonProperty("mimeType") String mimeType, - @JsonProperty("blob") String blob) implements ResourceContents { - } // @formatter:on - - // --------------------------- - // Prompt Interfaces - // --------------------------- - /** - * A prompt or prompt template that the server offers. - * - * @param name The name of the prompt or prompt template. - * @param description An optional description of what this prompt provides. - * @param arguments A list of arguments to use for templating the prompt. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record Prompt( // @formatter:off - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("arguments") List arguments) { - } // @formatter:on - - /** - * Describes an argument that a prompt can accept. - * - * @param name The name of the argument. - * @param description A human-readable description of the argument. - * @param required Whether this argument must be provided. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record PromptArgument( // @formatter:off - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("required") Boolean required) { - }// @formatter:on - - /** - * Describes a message returned as part of a prompt. - * - * This is similar to `SamplingMessage`, but also supports the embedding of resources - * from the MCP server. - * - * @param role The sender or recipient of messages and data in a conversation. - * @param content The content of the message of type {@link Content}. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record PromptMessage( // @formatter:off - @JsonProperty("role") Role role, - @JsonProperty("content") Content content) { - } // @formatter:on - - /** - * The server's response to a prompts/list request from the client. - * - * @param prompts A list of prompts that the server provides. - * @param nextCursor An optional cursor for pagination. If present, indicates there - * are more prompts available. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ListPromptsResult( // @formatter:off - @JsonProperty("prompts") List prompts, - @JsonProperty("nextCursor") String nextCursor) { - }// @formatter:on - - /** - * Used by the client to get a prompt provided by the server. - * - * @param name The name of the prompt or prompt template. - * @param arguments Arguments to use for templating the prompt. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record GetPromptRequest(// @formatter:off - @JsonProperty("name") String name, - @JsonProperty("arguments") Map arguments) implements Request { - }// @formatter:off - - /** - * The server's response to a prompts/get request from the client. - * - * @param description An optional description for the prompt. - * @param messages A list of messages to display as part of the prompt. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record GetPromptResult( // @formatter:off - @JsonProperty("description") String description, - @JsonProperty("messages") List messages) { - } // @formatter:on - - // --------------------------- - // Tool Interfaces - // --------------------------- - /** - * The server's response to a tools/list request from the client. - * - * @param tools A list of tools that the server provides. - * @param nextCursor An optional cursor for pagination. If present, indicates there - * are more tools available. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ListToolsResult( // @formatter:off - @JsonProperty("tools") List tools, - @JsonProperty("nextCursor") String nextCursor) { - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record JsonSchema( // @formatter:off - @JsonProperty("type") String type, - @JsonProperty("properties") Map properties, - @JsonProperty("required") List required, - @JsonProperty("additionalProperties") Boolean additionalProperties) { - } // @formatter:on - - /** - * Represents a tool that the server provides. Tools enable servers to expose - * executable functionality to the system. Through these tools, you can interact with - * external systems, perform computations, and take actions in the real world. - * - * @param name A unique identifier for the tool. This name is used when calling the - * tool. - * @param description A human-readable description of what the tool does. This can be - * used by clients to improve the LLM's understanding of available tools. - * @param inputSchema A JSON Schema object that describes the expected structure of - * the arguments when calling this tool. This allows clients to validate tool - * arguments before sending them to the server. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record Tool( // @formatter:off - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("inputSchema") JsonSchema inputSchema) { - - public Tool(String name, String description, String schema) { - this(name, description, parseSchema(schema)); - } - - } // @formatter:on - - private static JsonSchema parseSchema(String schema) { - try { - return OBJECT_MAPPER.readValue(schema, JsonSchema.class); - } - catch (IOException e) { - throw new IllegalArgumentException("Invalid schema: " + schema, e); - } - } - - /** - * Used by the client to call a tool provided by the server. - * - * @param name The name of the tool to call. This must match a tool name from - * tools/list. - * @param arguments Arguments to pass to the tool. These must conform to the tool's - * input schema. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record CallToolRequest(// @formatter:off - @JsonProperty("name") String name, - @JsonProperty("arguments") Map arguments) implements Request { - }// @formatter:off - - /** - * The server's response to a tools/call request from the client. - * - * @param content A list of content items representing the tool's output. Each item can be text, an image, - * or an embedded resource. - * @param isError If true, indicates that the tool execution failed and the content contains error information. - * If false or absent, indicates successful execution. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record CallToolResult( // @formatter:off - @JsonProperty("content") List content, - @JsonProperty("isError") Boolean isError) { - } // @formatter:on - - // --------------------------- - // Sampling Interfaces - // --------------------------- - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ModelPreferences(// @formatter:off - @JsonProperty("hints") List hints, - @JsonProperty("costPriority") Double costPriority, - @JsonProperty("speedPriority") Double speedPriority, - @JsonProperty("intelligencePriority") Double intelligencePriority) { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ModelHint(@JsonProperty("name") String name) { - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record SamplingMessage(// @formatter:off - @JsonProperty("role") Role role, - @JsonProperty("content") Content content) { - } // @formatter:on - - // Sampling and Message Creation - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record CreateMessageRequest(// @formatter:off - @JsonProperty("messages") List messages, - @JsonProperty("modelPreferences") ModelPreferences modelPreferences, - @JsonProperty("systemPrompt") String systemPrompt, - @JsonProperty("includeContext") ContextInclusionStrategy includeContext, - @JsonProperty("temperature") Double temperature, - @JsonProperty("maxTokens") int maxTokens, - @JsonProperty("stopSequences") List stopSequences, - @JsonProperty("metadata") Map metadata) implements Request { - - public enum ContextInclusionStrategy { - @JsonProperty("none") NONE, - @JsonProperty("this_server") THIS_SERVER, - @JsonProperty("all_server") ALL_SERVERS - } - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record CreateMessageResult(// @formatter:off - @JsonProperty("role") Role role, - @JsonProperty("content") Content content, - @JsonProperty("model") String model, - @JsonProperty("stopReason") StopReason stopReason) { - - public enum StopReason { - @JsonProperty("end_turn") END_TURN, - @JsonProperty("stop_sequence") STOP_SEQUENCE, - @JsonProperty("max_tokens") MAX_TOKENS - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private Role role = Role.ASSISTANT; - private Content content; - private String model; - private StopReason stopReason = StopReason.END_TURN; - - public Builder role(Role role) { - this.role = role; - return this; - } - - public Builder content(Content content) { - this.content = content; - return this; - } - - public Builder model(String model) { - this.model = model; - return this; - } - - public Builder stopReason(StopReason stopReason) { - this.stopReason = stopReason; - return this; - } - - public Builder message(String message) { - this.content = new TextContent(message); - return this; - } - - public CreateMessageResult build() { - return new CreateMessageResult(role, content, model, stopReason); - } - } - }// @formatter:on - - // --------------------------- - // Pagination Interfaces - // --------------------------- - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record PaginatedRequest(@JsonProperty("cursor") String cursor) { - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record PaginatedResult(@JsonProperty("nextCursor") String nextCursor) { - } - - // --------------------------- - // Progress and Logging - // --------------------------- - @JsonIgnoreProperties(ignoreUnknown = true) - public record ProgressNotification(// @formatter:off - @JsonProperty("progressToken") String progressToken, - @JsonProperty("progress") double progress, - @JsonProperty("total") Double total) { - }// @formatter:on - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to send - * structured log messages to clients. Clients can control logging verbosity by - * setting minimum log levels, with servers sending notifications containing severity - * levels, optional logger names, and arbitrary JSON-serializable data. - * - * @param level The severity levels. The mimimum log level is set by the client. - * @param logger The logger that generated the message. - * @param data JSON-serializable logging data. - */ - @JsonIgnoreProperties(ignoreUnknown = true) - public record LoggingMessageNotification(// @formatter:off - @JsonProperty("level") LoggingLevel level, - @JsonProperty("logger") String logger, - @JsonProperty("data") String data) { - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private LoggingLevel level = LoggingLevel.INFO; - private String logger = "server"; - private String data; - - public Builder level(LoggingLevel level) { - this.level = level; - return this; - } - - public Builder logger(String logger) { - this.logger = logger; - return this; - } - - public Builder data(String data) { - this.data = data; - return this; - } - - public LoggingMessageNotification build() { - return new LoggingMessageNotification(level, logger, data); - } - } - }// @formatter:on - - public enum LoggingLevel {// @formatter:off - @JsonProperty("debug") DEBUG(0), - @JsonProperty("info") INFO(1), - @JsonProperty("notice") NOTICE(2), - @JsonProperty("warning") WARNING(3), - @JsonProperty("error") ERROR(4), - @JsonProperty("critical") CRITICAL(5), - @JsonProperty("alert") ALERT(6), - @JsonProperty("emergency") EMERGENCY(7); - - private final int level; - - LoggingLevel(int level) { - this.level = level; - } - - public int level() { - return level; - } - - } // @formatter:on - - // --------------------------- - // Autocomplete - // --------------------------- - public record CompleteRequest(PromptOrResourceReference ref, CompleteArgument argument) implements Request { - public sealed interface PromptOrResourceReference permits PromptReference, ResourceReference { - - String type(); - - } - - public record PromptReference(// @formatter:off - @JsonProperty("type") String type, - @JsonProperty("name") String name) implements PromptOrResourceReference { - }// @formatter:on - - public record ResourceReference(// @formatter:off - @JsonProperty("type") String type, - @JsonProperty("uri") String uri) implements PromptOrResourceReference { - }// @formatter:on - - public record CompleteArgument(// @formatter:off - @JsonProperty("name") String name, - @JsonProperty("value") String value) { - }// @formatter:on - } - - public record CompleteResult(CompleteCompletion completion) { - public record CompleteCompletion(// @formatter:off - @JsonProperty("values") List values, - @JsonProperty("total") Integer total, - @JsonProperty("hasMore") Boolean hasMore) { - }// @formatter:on - } - - // --------------------------- - // Content Types - // --------------------------- - @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type") - @JsonSubTypes({ @JsonSubTypes.Type(value = TextContent.class, name = "text"), - @JsonSubTypes.Type(value = ImageContent.class, name = "image"), - @JsonSubTypes.Type(value = EmbeddedResource.class, name = "resource") }) - public sealed interface Content permits TextContent, ImageContent, EmbeddedResource { - - default String type() { - if (this instanceof TextContent) { - return "text"; - } - else if (this instanceof ImageContent) { - return "image"; - } - else if (this instanceof EmbeddedResource) { - return "resource"; - } - throw new IllegalArgumentException("Unknown content type: " + this); - } - - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record TextContent( // @formatter:off - @JsonProperty("audience") List audience, - @JsonProperty("priority") Double priority, - @JsonProperty("text") String text) implements Content { // @formatter:on - - public TextContent(String content) { - this(null, null, content); - } - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ImageContent( // @formatter:off - @JsonProperty("audience") List audience, - @JsonProperty("priority") Double priority, - @JsonProperty("data") String data, - @JsonProperty("mimeType") String mimeType) implements Content { // @formatter:on - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record EmbeddedResource( // @formatter:off - @JsonProperty("audience") List audience, - @JsonProperty("priority") Double priority, - @JsonProperty("resource") ResourceContents resource) implements Content { // @formatter:on - } - - // --------------------------- - // Roots - // --------------------------- - /** - * Represents a root directory or file that the server can operate on. - * - * @param uri The URI identifying the root. This *must* start with file:// for now. - * This restriction may be relaxed in future versions of the protocol to allow other - * URI schemes. - * @param name An optional name for the root. This can be used to provide a - * human-readable identifier for the root, which may be useful for display purposes or - * for referencing the root in other parts of the application. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record Root( // @formatter:off - @JsonProperty("uri") String uri, - @JsonProperty("name") String name) { - } // @formatter:on - - /** - * The client's response to a roots/list request from the server. This result contains - * an array of Root objects, each representing a root directory or file that the - * server can operate on. - * - * @param roots An array of Root objects, each representing a root directory or file - * that the server can operate on. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ListRootsResult( // @formatter:off - @JsonProperty("roots") List roots) { - } // @formatter:on - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java deleted file mode 100644 index 135914322..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java +++ /dev/null @@ -1,13 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.spec; - -/** - * Marker interface for the server-side MCP transport. - * - * @author Christian Tzolov - */ -public interface ServerMcpTransport extends McpTransport { - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java deleted file mode 100644 index 0f799ca0f..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.util; - -import java.util.Collection; -import java.util.Map; - -import reactor.util.annotation.Nullable; - -/** - * Miscellaneous utility methods. - * - * @author Christian Tzolov - */ - -public final class Utils { - - /** - * Check whether the given {@code String} contains actual text. - *

    - * More specifically, this method returns {@code true} if the {@code String} is not - * {@code null}, its length is greater than 0, and it contains at least one - * non-whitespace character. - * @param str the {@code String} to check (may be {@code null}) - * @return {@code true} if the {@code String} is not {@code null}, its length is - * greater than 0, and it does not contain whitespace only - * @see Character#isWhitespace - */ - public static boolean hasText(@Nullable String str) { - return (str != null && !str.isBlank()); - } - - /** - * Return {@code true} if the supplied Collection is {@code null} or empty. Otherwise, - * return {@code false}. - * @param collection the Collection to check - * @return whether the given Collection is empty - */ - public static boolean isEmpty(@Nullable Collection collection) { - return (collection == null || collection.isEmpty()); - } - - /** - * Return {@code true} if the supplied Map is {@code null} or empty. Otherwise, return - * {@code false}. - * @param map the Map to check - * @return whether the given Map is empty - */ - public static boolean isEmpty(@Nullable Map map) { - return (map == null || map.isEmpty()); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java deleted file mode 100644 index 661c629ea..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ /dev/null @@ -1,440 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import java.time.Duration; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; - -import io.modelcontextprotocol.spec.ClientMcpTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest; -import io.modelcontextprotocol.spec.McpTransport; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for the {@link McpAsyncClient} that can be used with different - * {@link McpTransport} implementations. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ -// KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpAsyncClientTests { - - private McpAsyncClient mcpAsyncClient; - - protected ClientMcpTransport mcpTransport; - - private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - - abstract protected ClientMcpTransport createMcpTransport(); - - protected void onStart() { - } - - protected void onClose() { - } - - protected Duration getTimeoutDuration() { - return Duration.ofSeconds(2); - } - - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); - - assertThatCode(() -> { - mcpAsyncClient = McpClient.async(mcpTransport) - .requestTimeout(getTimeoutDuration()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); - }).doesNotThrowAnyException(); - } - - @AfterEach - void tearDown() { - if (mcpAsyncClient != null) { - assertThatCode(() -> mcpAsyncClient.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - onClose(); - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Request timeout must not be null"); - } - - @Test - void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listTools(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); - } - - @Test - void testListTools() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listTools(null)).consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); - - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }).verifyComplete(); - } - - @Test - void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.ping().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); - } - - @Test - void testPing() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - assertThatCode(() -> mcpAsyncClient.ping().block()).doesNotThrowAnyException(); - } - - @Test - void testCallToolWithoutInitialization() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpAsyncClient.callTool(callToolRequest).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); - } - - @Test - void testCallTool() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)).consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - }).verifyComplete(); - } - - @Test - void testCallToolWithInvalidTool() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpAsyncClient.callTool(invalidRequest).block()).isInstanceOf(Exception.class); - } - - @Test - void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listResources(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); - } - - @Test - void testListResources() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listResources(null)).consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }).verifyComplete(); - } - - @Test - void testMcpAsyncClientState() { - assertThat(mcpAsyncClient).isNotNull(); - } - - @Test - void testListPromptsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listPrompts(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing prompts"); - } - - @Test - void testListPrompts() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listPrompts(null)).consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }).verifyComplete(); - } - - @Test - void testGetPromptWithoutInitialization() { - GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - - assertThatThrownBy(() -> mcpAsyncClient.getPrompt(request).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before getting prompts"); - } - - @Test - void testGetPrompt() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of()))) - .consumeNextWith(prompt -> { - assertThat(prompt).isNotNull().satisfies(result -> { - assertThat(result.messages()).isNotEmpty(); - assertThat(result.messages()).hasSize(1); - }); - }) - .verifyComplete(); - } - - @Test - void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.rootsListChangedNotification().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); - } - - @Test - void testRootsListChanged() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - assertThatCode(() -> mcpAsyncClient.rootsListChangedNotification().block()).doesNotThrowAnyException(); - } - - @Test - void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - assertThatCode(() -> client.initialize().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - - assertThatCode(() -> client.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpAsyncClient.addRoot(newRoot).block()).doesNotThrowAnyException(); - } - - @Test - void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpAsyncClient.addRoot(null).block()).hasMessageContaining("Root must not be null"); - } - - @Test - void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpAsyncClient.addRoot(root).block(); - mcpAsyncClient.removeRoot(root.uri()).block(); - }).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpAsyncClient.removeRoot("nonexistent-uri").block()) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); - } - - @Test - @Disabled - void testReadResource() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - }).verifyComplete(); - } - }).verifyComplete(); - } - - @Test - void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listResourceTemplates().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); - } - - @Test - void testListResourceTemplates() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listResourceTemplates()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }).verifyComplete(); - } - - // @Test - void testResourceSubscription() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - - // Test subscribe - StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .verifyComplete(); - - // Test unsubscribe - StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .verifyComplete(); - } - }).verifyComplete(); - } - - @Test - void testNotificationHandlers() { - AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); - AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); - AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) - .build(); - - assertThatCode(() -> { - client.initialize().block(); - client.closeGracefully().block(); - }).doesNotThrowAnyException(); - } - - @Test - void testInitializeWithSamplingCapability() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder().sampling().build(); - - var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) - .capabilities(capabilities) - .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) - .build(); - - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); - } - - @Test - void testInitializeWithAllCapabilities() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder() - .experimental(Map.of("feature", "test")) - .roots(true) - .sampling() - .build(); - - Function> samplingHandler = request -> Mono - .just(CreateMessageResult.builder().message("test").model("test-model").build()); - var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) - .capabilities(capabilities) - .sampling(samplingHandler) - .build(); - - assertThatCode(() -> { - var result = client.initialize().block(Duration.ofSeconds(10)); - assertThat(result).isNotNull(); - assertThat(result.capabilities()).isNotNull(); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG).block()) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); - } - - @Test - void testLoggingLevels() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(level)).verifyComplete(); - } - } - - @Test - void testLoggingConsumer() { - AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) - .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) - .build(); - - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(null).block()) - .hasMessageContaining("Logging level must not be null"); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java deleted file mode 100644 index 6f8cf198e..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ /dev/null @@ -1,371 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import java.time.Duration; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; - -import io.modelcontextprotocol.spec.ClientMcpTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.ListResourceTemplatesResult; -import io.modelcontextprotocol.spec.McpSchema.ListResourcesResult; -import io.modelcontextprotocol.spec.McpSchema.ListToolsResult; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest; -import io.modelcontextprotocol.spec.McpSchema.TextContent; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Unit tests for MCP Client Session functionality. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ -// KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpSyncClientTests { - - private McpSyncClient mcpSyncClient; - - private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - - protected ClientMcpTransport mcpTransport; - - abstract protected ClientMcpTransport createMcpTransport(); - - abstract protected void onStart(); - - abstract protected void onClose(); - - protected Duration getTimeoutDuration() { - return Duration.ofSeconds(2); - } - - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); - - assertThatCode(() -> { - mcpSyncClient = McpClient.sync(mcpTransport) - .requestTimeout(getTimeoutDuration()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); - }).doesNotThrowAnyException(); - } - - @AfterEach - void tearDown() { - if (mcpSyncClient != null) { - assertThatCode(() -> mcpSyncClient.close()).doesNotThrowAnyException(); - } - onClose(); - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Request timeout must not be null"); - } - - @Test - void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listTools(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); - } - - @Test - void testListTools() { - mcpSyncClient.initialize(); - ListToolsResult tools = mcpSyncClient.listTools(null); - - assertThat(tools).isNotNull().satisfies(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); - - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }); - } - - @Test - void testCallToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4)))) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); - } - - @Test - void testCallTools() { - mcpSyncClient.initialize(); - CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); - - assertThat(toolResult).isNotNull().satisfies(result -> { - - assertThat(result.content()).hasSize(1); - - TextContent content = (TextContent) result.content().get(0); - - assertThat(content).isNotNull(); - assertThat(content.text()).isNotNull(); - assertThat(content.text()).contains("7"); - }); - } - - @Test - void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.ping()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); - } - - @Test - void testPing() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); - } - - @Test - void testCallToolWithoutInitialization() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpSyncClient.callTool(callToolRequest)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); - } - - @Test - void testCallTool() { - mcpSyncClient.initialize(); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - - CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); - - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - } - - @Test - void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); - } - - @Test - void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.rootsListChangedNotification()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); - } - - @Test - void testRootsListChanged() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); - } - - @Test - void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResources(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); - } - - @Test - void testListResources() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); - - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - } - - @Test - void testClientSessionState() { - assertThat(mcpSyncClient).isNotNull(); - } - - @Test - void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); - } - - @Test - void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); - } - - @Test - void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); - } - - @Test - void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpSyncClient.addRoot(root); - mcpSyncClient.removeRoot(root.uri()); - }).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); - } - - @Test - void testReadResourceWithoutInitialization() { - assertThatThrownBy(() -> { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - mcpSyncClient.readResource(resource); - }).isInstanceOf(McpError.class).hasMessage("Client must be initialized before reading resources"); - } - - @Test - void testReadResource() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); - - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - ReadResourceResult result = mcpSyncClient.readResource(firstResource); - - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - } - } - - @Test - void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResourceTemplates(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); - } - - @Test - void testListResourceTemplates() { - mcpSyncClient.initialize(); - ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); - - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - } - - // @Test - void testResourceSubscription() { - ListResourcesResult resources = mcpSyncClient.listResources(null); - - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - - // Test subscribe - assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - - // Test unsubscribe - assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - } - } - - @Test - void testNotificationHandlers() { - AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); - AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); - AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - - var transport = createMcpTransport(); - var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) - .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) - .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); - } - - @Test - void testLoggingLevels() { - mcpSyncClient.initialize(); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); - } - } - - @Test - void testLoggingConsumer() { - AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) - .loggingConsumer(notification -> logReceived.set(true)) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null"); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpAsyncClientTests.java deleted file mode 100644 index 7cc673fa1..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpAsyncClientTests.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import java.time.Duration; - -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; -import org.junit.jupiter.api.Timeout; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; - -/** - * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { - - String host = "http://localhost:3004"; - - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - @Override - protected ClientMcpTransport createMcpTransport() { - return new HttpClientSseClientTransport(host); - } - - @Override - protected void onStart() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @Override - protected void onClose() { - container.stop(); - } - - @Override - protected Duration getTimeoutDuration() { - return Duration.ofMillis(300); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java deleted file mode 100644 index ce74812b7..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import io.modelcontextprotocol.client.transport.ServerParameters; -import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; -import org.junit.jupiter.api.Timeout; - -/** - * Tests for the {@link McpAsyncClient} with {@link StdioClientTransport}. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ -@Timeout(15) // Giving extra time beyond the client timeout -class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { - - @Override - protected ClientMcpTransport createMcpTransport() { - ServerParameters stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") - .build(); - return new StdioClientTransport(stdioParams); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java deleted file mode 100644 index 7ae65253a..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import java.util.concurrent.atomic.AtomicReference; - -import io.modelcontextprotocol.client.transport.ServerParameters; -import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Tests for the {@link McpSyncClient} with {@link StdioClientTransport}. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ -@Timeout(15) // Giving extra time beyond the client timeout -class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { - - @Override - protected ClientMcpTransport createMcpTransport() { - ServerParameters stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") - .build(); - - return new StdioClientTransport(stdioParams); - } - - @Test - void customErrorHandlerShouldReceiveErrors() { - AtomicReference receivedError = new AtomicReference<>(); - - ((StdioClientTransport) mcpTransport).setStdErrorHandler(error -> receivedError.set(error)); - - String errorMessage = "Test error"; - ((StdioClientTransport) mcpTransport).getErrorSink().tryEmitNext(errorMessage); - - assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); - } - - @Override - protected void onStart() { - } - - @Override - protected void onClose() { - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java deleted file mode 100644 index 294056fbe..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ /dev/null @@ -1,278 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client.transport; - -import java.time.Duration; -import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Function; - -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.test.StepVerifier; - -import org.springframework.http.codec.ServerSentEvent; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; - -/** - * Tests for the {@link HttpClientSseClientTransport} class. - * - * @author Christian Tzolov - */ -@Timeout(15) -class HttpClientSseClientTransportTests { - - static String host = "http://localhost:3001"; - - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - private TestHttpClientSseClientTransport transport; - - // Test class to access protected methods - static class TestHttpClientSseClientTransport extends HttpClientSseClientTransport { - - private final AtomicInteger inboundMessageCount = new AtomicInteger(0); - - private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - - public TestHttpClientSseClientTransport(String baseUri) { - super(baseUri); - } - - public int getInboundMessageCount() { - return inboundMessageCount.get(); - } - - public void simulateEndpointEvent(String jsonMessage) { - events.tryEmitNext(ServerSentEvent.builder().event("endpoint").data(jsonMessage).build()); - inboundMessageCount.incrementAndGet(); - } - - public void simulateMessageEvent(String jsonMessage) { - events.tryEmitNext(ServerSentEvent.builder().event("message").data(jsonMessage).build()); - inboundMessageCount.incrementAndGet(); - } - - } - - void startContainer() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @BeforeEach - void setUp() { - startContainer(); - transport = new TestHttpClientSseClientTransport(host); - transport.connect(Function.identity()).block(); - } - - @AfterEach - void afterEach() { - if (transport != null) { - assertThatCode(() -> transport.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - cleanup(); - } - - void cleanup() { - container.stop(); - } - - @Test - void testMessageProcessing() { - // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", - Map.of("key", "value")); - - // Simulate receiving the message - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "test-method", - "id": "test-id", - "params": {"key": "value"} - } - """); - - // Subscribe to messages and verify - StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - - assertThat(transport.getInboundMessageCount()).isEqualTo(1); - } - - @Test - void testResponseMessageProcessing() { - // Simulate receiving a response message - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "id": "test-id", - "result": {"status": "success"} - } - """); - - // Create and send a request message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", - Map.of("key", "value")); - - // Verify message handling - StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - - assertThat(transport.getInboundMessageCount()).isEqualTo(1); - } - - @Test - void testErrorMessageProcessing() { - // Simulate receiving an error message - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "id": "test-id", - "error": { - "code": -32600, - "message": "Invalid Request" - } - } - """); - - // Create and send a request message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", - Map.of("key", "value")); - - // Verify message handling - StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - - assertThat(transport.getInboundMessageCount()).isEqualTo(1); - } - - @Test - void testNotificationMessageProcessing() { - // Simulate receiving a notification message (no id) - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "update", - "params": {"status": "processing"} - } - """); - - // Verify the notification was processed - assertThat(transport.getInboundMessageCount()).isEqualTo(1); - } - - @Test - void testGracefulShutdown() { - // Test graceful shutdown - StepVerifier.create(transport.closeGracefully()).verifyComplete(); - - // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", - Map.of("key", "value")); - - // Verify message is not processed after shutdown - StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - - // Message count should remain 0 after shutdown - assertThat(transport.getInboundMessageCount()).isEqualTo(0); - } - - @Test - void testRetryBehavior() { - // Create a client that simulates connection failures - HttpClientSseClientTransport failingTransport = new HttpClientSseClientTransport("http://non-existent-host"); - - // Verify that the transport attempts to reconnect - StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete(); - - // Clean up - failingTransport.closeGracefully().block(); - } - - @Test - void testMultipleMessageProcessing() { - // Simulate receiving multiple messages in sequence - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "method1", - "id": "id1", - "params": {"key": "value1"} - } - """); - - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "method2", - "id": "id2", - "params": {"key": "value2"} - } - """); - - // Create and send corresponding messages - JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", "id1", - Map.of("key", "value1")); - - JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", "id2", - Map.of("key", "value2")); - - // Verify both messages are processed - StepVerifier.create(transport.sendMessage(message1).then(transport.sendMessage(message2))).verifyComplete(); - - // Verify message count - assertThat(transport.getInboundMessageCount()).isEqualTo(2); - } - - @Test - void testMessageOrderPreservation() { - // Simulate receiving messages in a specific order - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "first", - "id": "1", - "params": {"sequence": 1} - } - """); - - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "second", - "id": "2", - "params": {"sequence": 2} - } - """); - - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "third", - "id": "3", - "params": {"sequence": 3} - } - """); - - // Verify message count and order - assertThat(transport.getInboundMessageCount()).isEqualTo(3); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java deleted file mode 100644 index dcc103b54..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ /dev/null @@ -1,464 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.time.Duration; -import java.util.List; - -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptMessage; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for the {@link McpAsyncServer} that can be used with different - * {@link McpTransport} implementations. - * - * @author Christian Tzolov - */ -// KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpAsyncServerTests { - - private static final String TEST_TOOL_NAME = "test-tool"; - - private static final String TEST_RESOURCE_URI = "test://resource"; - - private static final String TEST_PROMPT_NAME = "test-prompt"; - - abstract protected ServerMcpTransport createMcpTransport(); - - protected void onStart() { - } - - protected void onClose() { - } - - @BeforeEach - void setUp() { - } - - @AfterEach - void tearDown() { - onClose(); - } - - // --------------------------------------- - // Server Lifecycle Tests - // --------------------------------------- - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Server info must not be null"); - } - - @Test - void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); - } - - @Test - void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testAddTool() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveTool() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testNotifyToolsListChanged() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Resources Tests - // --------------------------------------- - - @Test - void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); - - StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null)) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithoutCapability() { - // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); - - StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - }); - } - - @Test - void testRemoveResourceWithoutCapability() { - // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); - - StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - }); - } - - // --------------------------------------- - // Prompts Tests - // --------------------------------------- - - @Test - void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddPromptWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(false).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null)) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null"); - }); - } - - @Test - void testAddPromptWithoutCapability() { - // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); - - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - - StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - }); - } - - @Test - void testRemovePromptWithoutCapability() { - // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); - - StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - }); - } - - @Test - void testRemovePrompt() { - String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; - - Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) - .build(); - - StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - }); - - assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - - @Test - void testRootsChangeConsumers() { - // Test with single consumer - var rootsReceived = new McpSchema.Root[1]; - var consumerCalled = new boolean[1]; - - var singleConsumerServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { - consumerCalled[0] = true; - if (!roots.isEmpty()) { - rootsReceived[0] = roots.get(0); - } - }))) - .build(); - - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test with multiple consumers - var consumer1Called = new boolean[1]; - var consumer2Called = new boolean[1]; - var rootsContent = new List[1]; - - var multipleConsumersServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { - consumer1Called[0] = true; - rootsContent[0] = roots; - }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true))) - .build(); - - assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { - throw new RuntimeException("Test error"); - })) - .build(); - - assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java deleted file mode 100644 index bdcd7ae3a..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ /dev/null @@ -1,432 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.util.List; - -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptMessage; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransport} implementations. - * - * @author Christian Tzolov - */ -// KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpSyncServerTests { - - private static final String TEST_TOOL_NAME = "test-tool"; - - private static final String TEST_RESOURCE_URI = "test://resource"; - - private static final String TEST_PROMPT_NAME = "test-prompt"; - - abstract protected ServerMcpTransport createMcpTransport(); - - protected void onStart() { - } - - protected void onClose() { - } - - @BeforeEach - void setUp() { - // onStart(); - } - - @AfterEach - void tearDown() { - onClose(); - } - - // --------------------------------------- - // Server Lifecycle Tests - // --------------------------------------- - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Server info must not be null"); - } - - @Test - void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); - } - - @Test - void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - assertThatCode(() -> mcpSyncServer - .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false)))) - .doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> new CallToolResult(List.of(), false)) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool, - args -> new CallToolResult(List.of(), false)))) - .isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveTool() { - Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, args -> new CallToolResult(List.of(), false)) - .build(); - - assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) - .hasMessage("Tool with name 'nonexistent-tool' not found"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Resources Tests - // --------------------------------------- - - @Test - void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); - - assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null)) - .isInstanceOf(McpError.class) - .hasMessage("Resource must not be null"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); - - assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - } - - @Test - void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - } - - // --------------------------------------- - // Prompts Tests - // --------------------------------------- - - @Test - void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddPromptWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(false).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null)) - .isInstanceOf(McpError.class) - .hasMessage("Prompt registration must not be null"); - } - - @Test - void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - } - - @Test - void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - } - - @Test - void testRemovePrompt() { - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) - .build(); - - assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - - @Test - void testRootsChangeConsumers() { - // Test with single consumer - var rootsReceived = new McpSchema.Root[1]; - var consumerCalled = new boolean[1]; - - var singleConsumerServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { - consumerCalled[0] = true; - if (!roots.isEmpty()) { - rootsReceived[0] = roots.get(0); - } - })) - .build(); - - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test with multiple consumers - var consumer1Called = new boolean[1]; - var consumer2Called = new boolean[1]; - var rootsContent = new List[1]; - - var multipleConsumersServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { - consumer1Called[0] = true; - rootsContent[0] = roots; - }, roots -> consumer2Called[0] = true)) - .build(); - - assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { - throw new RuntimeException("Test error"); - })) - .build(); - - assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java deleted file mode 100644 index 0ab72a99f..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java +++ /dev/null @@ -1,69 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; -import java.io.InputStream; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; - -public class BlockingInputStream extends InputStream { - - private final BlockingQueue queue = new LinkedBlockingQueue<>(); - - private volatile boolean completed = false; - - private volatile boolean closed = false; - - @Override - public int read() throws IOException { - if (closed) { - throw new IOException("Stream is closed"); - } - - try { - Integer value = queue.poll(); - if (value == null) { - if (completed) { - return -1; - } - value = queue.take(); // Blocks until data is available - if (value == null && completed) { - return -1; - } - } - return value; - } - catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IOException("Read interrupted", e); - } - } - - public void write(int b) { - if (!closed && !completed) { - queue.offer(b); - } - } - - public void write(byte[] data) { - if (!closed && !completed) { - for (byte b : data) { - queue.offer((int) b & 0xFF); - } - } - } - - public void complete() { - this.completed = true; - } - - @Override - public void close() { - this.closed = true; - this.completed = true; - this.queue.clear(); - } - -} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportIntegrationTests.java deleted file mode 100644 index 4a292da31..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportIntegrationTests.java +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ -package io.modelcontextprotocol.server.transport; - -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import org.apache.catalina.Context; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import reactor.test.StepVerifier; - -import org.springframework.web.client.RestClient; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.awaitility.Awaitility.await; - -public class HttpServletSseServerTransportIntegrationTests { - - private static final int PORT = 8184; - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private HttpServletSseServerTransport mcpServerTransport; - - McpClient.SyncSpec clientBuilder; - - private Tomcat tomcat; - - @BeforeEach - public void before() { - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - Context context = tomcat.addContext("", baseDir); - - // Create and configure the transport - mcpServerTransport = new HttpServletSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - - // Add transport servlet to Tomcat - org.apache.catalina.Wrapper wrapper = context.createWrapper(); - wrapper.setName("mcpServlet"); - wrapper.setServlet(mcpServerTransport); - wrapper.setLoadOnStartup(1); - wrapper.setAsyncSupported(true); - context.addChild(wrapper); - context.addServletMappingDecoded("/*", "mcpServlet"); - - try { - var connector = tomcat.getConnector(); - connector.setAsyncTimeout(3000); - tomcat.start(); - assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); - } - catch (Exception e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); - } - - @AfterEach - public void after() { - if (mcpServerTransport != null) { - mcpServerTransport.closeGracefully().block(); - } - if (tomcat != null) { - try { - tomcat.stop(); - tomcat.destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - - @Test - void testCreateMessageWithoutInitialization() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized. Call the initialize method first!"); - }); - } - - @Test - void testCreateMessageWithoutSamplingCapabilities() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - }); - } - - @Test - void testCreateMessageSuccess() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - } - - @Test - void testRootsSuccess() { - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpServer.listRoots().roots()).containsAll(roots); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - mcpClient.close(); - mcpServer.close(); - } - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testToolCallSuccess() { - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - var mcpClient = clientBuilder.build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testToolListChangeHandlingSuccess() { - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - AtomicReference> toolsRef = new AtomicReference<>(); - var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - toolsRef.set(toolsUpdate); - }).build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(toolsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(toolsRef.get()).containsAll(List.of(tool1.tool())); - }); - - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(toolsRef.get()).isEmpty(); - }); - - McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(toolsRef.get()).containsAll(List.of(tool2.tool())); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testInitialize() { - var mcpServer = McpServer.sync(mcpServerTransport).build(); - var mcpClient = clientBuilder.build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.close(); - mcpServer.close(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportTests.java deleted file mode 100644 index 43e5019fc..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportTests.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.io.ByteArrayOutputStream; -import java.io.InputStream; -import java.io.PrintStream; -import java.nio.charset.StandardCharsets; -import java.util.Map; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Tests for {@link StdioServerTransport}. - * - * @author Christian Tzolov - */ -class StdioServerTransportTests { - - private final InputStream originalIn = System.in; - - private final PrintStream originalOut = System.out; - - private final PrintStream originalErr = System.err; - - private ByteArrayOutputStream testOut; - - private ByteArrayOutputStream testErr; - - private PrintStream testOutPrintStream; - - private StdioServerTransport transport; - - private ObjectMapper objectMapper; - - @BeforeEach - void setUp() { - testOut = new ByteArrayOutputStream(); - testErr = new ByteArrayOutputStream(); - testOutPrintStream = new PrintStream(testOut, true); - System.setOut(testOutPrintStream); - System.setErr(new PrintStream(testErr)); - - objectMapper = new ObjectMapper(); - } - - @AfterEach - void tearDown() { - if (transport != null) { - transport.closeGracefully().block(); - } - if (testOutPrintStream != null) { - testOutPrintStream.close(); - } - System.setIn(originalIn); - System.setOut(originalOut); - System.setErr(originalErr); - } - - @Test - void shouldHandleIncomingMessages() throws Exception { - // Prepare test input - String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{},\"id\":1}"; - - // Create transport with test streams - transport = new StdioServerTransport(objectMapper); - - // Parse expected message - McpSchema.JSONRPCRequest expected = objectMapper.readValue(jsonMessage, McpSchema.JSONRPCRequest.class); - - // Connect transport with message handler and verify message - StepVerifier.create(transport.connect(message -> message.doOnNext(msg -> { - McpSchema.JSONRPCRequest received = (McpSchema.JSONRPCRequest) msg; - assertThat(received.id()).isEqualTo(expected.id()); - assertThat(received.method()).isEqualTo(expected.method()); - }))).verifyComplete(); - } - - @Test - @Disabled - void shouldHandleOutgoingMessages() throws Exception { - // Create transport with test streams - transport = new StdioServerTransport(objectMapper); - // transport = new StdioServerTransport(objectMapper, new BlockingInputStream(), - // testOutPrintStream); - - // Create test messages - JSONRPCRequest initMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "init", "init-id", - Map.of("init", "true")); - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test", "test-id", - Map.of("key", "value")); - - // Connect transport, send messages, and verify output in a reactive chain - StepVerifier.create(transport.connect(message -> message) - .then(transport.sendMessage(initMessage)) - // .then(Mono.fromRunnable(() -> testOut.reset())) // Clear buffer after init - // message - .then(transport.sendMessage(testMessage)) - .then(Mono.fromCallable(() -> { - String output = testOut.toString(StandardCharsets.UTF_8); - assertThat(output).contains("\"jsonrpc\":\"2.0\""); - assertThat(output).contains("\"method\":\"test\""); - assertThat(output).contains("\"id\":\"test-id\""); - return null; - }))).verifyComplete(); - } - - @Test - void shouldWaitForProcessorsBeforeSendingMessage() { - // Create transport with test streams - transport = new StdioServerTransport(objectMapper); - - // Create test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test", "test-id", - Map.of("key", "value")); - - // Try to send message before connecting (before processors are ready) - StepVerifier.create(transport.sendMessage(testMessage)).verifyTimeout(java.time.Duration.ofMillis(100)); - - // Connect transport and verify message can be sent - StepVerifier.create(transport.connect(message -> message).then(transport.sendMessage(testMessage))) - .verifyComplete(); - } - - @Test - void shouldCloseGracefully() { - // Create transport with test streams - transport = new StdioServerTransport(objectMapper); - - // Create test message - JSONRPCRequest initMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "init", "init-id", - Map.of("init", "true")); - - // Connect transport, send message, and close gracefully in a reactive chain - StepVerifier - .create(transport.connect(message -> message) - .then(transport.sendMessage(initMessage)) - .then(transport.closeGracefully())) - .verifyComplete(); - - // Verify error log is empty - assertThat(testErr.toString()).doesNotContain("Error"); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultMcpSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultMcpSessionTests.java deleted file mode 100644 index 9d011afff..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultMcpSessionTests.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.spec; - -import java.time.Duration; -import java.util.Map; - -import com.fasterxml.jackson.core.type.TypeReference; -import io.modelcontextprotocol.MockMcpTransport; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.test.StepVerifier; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for {@link DefaultMcpSession} that verifies its JSON-RPC message handling, - * request-response correlation, and notification processing. - * - * @author Christian Tzolov - */ -class DefaultMcpSessionTests { - - private static final Logger logger = LoggerFactory.getLogger(DefaultMcpSessionTests.class); - - private static final Duration TIMEOUT = Duration.ofSeconds(5); - - private static final String TEST_METHOD = "test.method"; - - private static final String TEST_NOTIFICATION = "test.notification"; - - private static final String ECHO_METHOD = "echo"; - - private DefaultMcpSession session; - - private MockMcpTransport transport; - - @BeforeEach - void setUp() { - transport = new MockMcpTransport(); - session = new DefaultMcpSession(TIMEOUT, transport, Map.of(), - Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))); - } - - @AfterEach - void tearDown() { - if (session != null) { - session.close(); - } - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> new DefaultMcpSession(null, transport, Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("requstTimeout can not be null"); - - assertThatThrownBy(() -> new DefaultMcpSession(TIMEOUT, null, Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("transport can not be null"); - } - - TypeReference responseType = new TypeReference<>() { - }; - - @Test - void testSendRequest() { - String testParam = "test parameter"; - String responseData = "test response"; - - // Create a Mono that will emit the response after the request is sent - Mono responseMono = session.sendRequest(TEST_METHOD, testParam, responseType); - // Verify response handling - StepVerifier.create(responseMono).then(() -> { - McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); - transport.simulateIncomingMessage( - new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), responseData, null)); - }).consumeNextWith(response -> { - // Verify the request was sent - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessageAsRequest(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCRequest.class); - McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) sentMessage; - assertThat(request.method()).isEqualTo(TEST_METHOD); - assertThat(request.params()).isEqualTo(testParam); - assertThat(response).isEqualTo(responseData); - }).verifyComplete(); - } - - @Test - void testSendRequestWithError() { - Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); - - // Verify error handling - StepVerifier.create(responseMono).then(() -> { - McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); - // Simulate error response - McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( - McpSchema.ErrorCodes.METHOD_NOT_FOUND, "Method not found", null); - transport.simulateIncomingMessage( - new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); - }).expectError(McpError.class).verify(); - } - - @Test - void testRequestTimeout() { - Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); - - // Verify timeout - StepVerifier.create(responseMono) - .expectError(java.util.concurrent.TimeoutException.class) - .verify(TIMEOUT.plusSeconds(1)); - } - - @Test - void testSendNotification() { - Map params = Map.of("key", "value"); - Mono notificationMono = session.sendNotification(TEST_NOTIFICATION, params); - - // Verify notification was sent - StepVerifier.create(notificationMono).consumeSubscriptionWith(response -> { - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCNotification.class); - McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) sentMessage; - assertThat(notification.method()).isEqualTo(TEST_NOTIFICATION); - assertThat(notification.params()).isEqualTo(params); - }).verifyComplete(); - } - - @Test - void testRequestHandling() { - String echoMessage = "Hello MCP!"; - Map> requestHandlers = Map.of(ECHO_METHOD, - params -> Mono.just(params)); - transport = new MockMcpTransport(); - session = new DefaultMcpSession(TIMEOUT, transport, requestHandlers, Map.of()); - - // Simulate incoming request - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, - "test-id", echoMessage); - transport.simulateIncomingMessage(request); - - // Verify response - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); - McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.result()).isEqualTo(echoMessage); - assertThat(response.error()).isNull(); - } - - @Test - void testNotificationHandling() { - Sinks.One receivedParams = Sinks.one(); - - transport = new MockMcpTransport(); - session = new DefaultMcpSession(TIMEOUT, transport, Map.of(), - Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); - - // Simulate incoming notification from the server - Map notificationParams = Map.of("status", "ready"); - - McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - TEST_NOTIFICATION, notificationParams); - - transport.simulateIncomingMessage(notification); - - // Verify handler was called - assertThat(receivedParams.asMono().block(Duration.ofSeconds(1))).isEqualTo(notificationParams); - } - - @Test - void testUnknownMethodHandling() { - // Simulate incoming request for unknown method - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "unknown.method", - "test-id", null); - transport.simulateIncomingMessage(request); - - // Verify error response - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); - McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.error()).isNotNull(); - assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND); - } - - @Test - void testGracefulShutdown() { - StepVerifier.create(session.closeGracefully()).verifyComplete(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java deleted file mode 100644 index 05e2ce28c..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ /dev/null @@ -1,592 +0,0 @@ -/* -* Copyright 2025 - 2025 the original author or authors. -*/ -package io.modelcontextprotocol.spec; - -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; -import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; -import net.javacrumbs.jsonunit.core.Option; -import org.junit.jupiter.api.Test; - -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * @author Christian Tzolov - */ -public class McpSchemaTests { - - ObjectMapper mapper = new ObjectMapper(); - - // Content Types Tests - - @Test - void testTextContent() throws Exception { - McpSchema.TextContent test = new McpSchema.TextContent("XXX"); - String value = mapper.writeValueAsString(test); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"type":"text","text":"XXX"}""")); - } - - @Test - void testTextContentDeserialization() throws Exception { - McpSchema.TextContent textContent = mapper.readValue(""" - {"type":"text","text":"XXX"}""", McpSchema.TextContent.class); - - assertThat(textContent).isNotNull(); - assertThat(textContent.type()).isEqualTo("text"); - assertThat(textContent.text()).isEqualTo("XXX"); - } - - @Test - void testContentDeserializationWrongType() throws Exception { - - assertThatThrownBy(() -> mapper.readValue(""" - {"type":"WRONG","text":"XXX"}""", McpSchema.TextContent.class)) - .isInstanceOf(InvalidTypeIdException.class) - .hasMessageContaining( - "Could not resolve type id 'WRONG' as a subtype of `io.modelcontextprotocol.spec.McpSchema$TextContent`: known type ids = [image, resource, text]"); - } - - @Test - void testImageContent() throws Exception { - McpSchema.ImageContent test = new McpSchema.ImageContent(null, null, "base64encodeddata", "image/png"); - String value = mapper.writeValueAsString(test); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"type":"image","data":"base64encodeddata","mimeType":"image/png"}""")); - } - - @Test - void testImageContentDeserialization() throws Exception { - McpSchema.ImageContent imageContent = mapper.readValue(""" - {"type":"image","data":"base64encodeddata","mimeType":"image/png"}""", McpSchema.ImageContent.class); - assertThat(imageContent).isNotNull(); - assertThat(imageContent.type()).isEqualTo("image"); - assertThat(imageContent.data()).isEqualTo("base64encodeddata"); - assertThat(imageContent.mimeType()).isEqualTo("image/png"); - } - - @Test - void testEmbeddedResource() throws Exception { - McpSchema.TextResourceContents resourceContents = new McpSchema.TextResourceContents("resource://test", - "text/plain", "Sample resource content"); - - McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); - - String value = mapper.writeValueAsString(test); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"type":"resource","resource":{"uri":"resource://test","mimeType":"text/plain","text":"Sample resource content"}}""")); - } - - @Test - void testEmbeddedResourceDeserialization() throws Exception { - McpSchema.EmbeddedResource embeddedResource = mapper.readValue( - """ - {"type":"resource","resource":{"uri":"resource://test","mimeType":"text/plain","text":"Sample resource content"}}""", - McpSchema.EmbeddedResource.class); - assertThat(embeddedResource).isNotNull(); - assertThat(embeddedResource.type()).isEqualTo("resource"); - assertThat(embeddedResource.resource()).isNotNull(); - assertThat(embeddedResource.resource().uri()).isEqualTo("resource://test"); - assertThat(embeddedResource.resource().mimeType()).isEqualTo("text/plain"); - assertThat(((TextResourceContents) embeddedResource.resource()).text()).isEqualTo("Sample resource content"); - } - - @Test - void testEmbeddedResourceWithBlobContents() throws Exception { - McpSchema.BlobResourceContents resourceContents = new McpSchema.BlobResourceContents("resource://test", - "application/octet-stream", "base64encodedblob"); - - McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); - - String value = mapper.writeValueAsString(test); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"type":"resource","resource":{"uri":"resource://test","mimeType":"application/octet-stream","blob":"base64encodedblob"}}""")); - } - - @Test - void testEmbeddedResourceWithBlobContentsDeserialization() throws Exception { - McpSchema.EmbeddedResource embeddedResource = mapper.readValue( - """ - {"type":"resource","resource":{"uri":"resource://test","mimeType":"application/octet-stream","blob":"base64encodedblob"}}""", - McpSchema.EmbeddedResource.class); - assertThat(embeddedResource).isNotNull(); - assertThat(embeddedResource.type()).isEqualTo("resource"); - assertThat(embeddedResource.resource()).isNotNull(); - assertThat(embeddedResource.resource().uri()).isEqualTo("resource://test"); - assertThat(embeddedResource.resource().mimeType()).isEqualTo("application/octet-stream"); - assertThat(((McpSchema.BlobResourceContents) embeddedResource.resource()).blob()) - .isEqualTo("base64encodedblob"); - } - - // JSON-RPC Message Types Tests - - @Test - void testJSONRPCRequest() throws Exception { - Map params = new HashMap<>(); - params.put("key", "value"); - - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", 1, - params); - - String value = mapper.writeValueAsString(request); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"jsonrpc":"2.0","method":"method_name","id":1,"params":{"key":"value"}}""")); - } - - @Test - void testJSONRPCNotification() throws Exception { - Map params = new HashMap<>(); - params.put("key", "value"); - - McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - "notification_method", params); - - String value = mapper.writeValueAsString(notification); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"jsonrpc":"2.0","method":"notification_method","params":{"key":"value"}}""")); - } - - @Test - void testJSONRPCResponse() throws Exception { - Map result = new HashMap<>(); - result.put("result_key", "result_value"); - - McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, result, null); - - String value = mapper.writeValueAsString(response); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"jsonrpc":"2.0","id":1,"result":{"result_key":"result_value"}}""")); - } - - @Test - void testJSONRPCResponseWithError() throws Exception { - McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( - McpSchema.ErrorCodes.INVALID_REQUEST, "Invalid request", null); - - McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, null, error); - - String value = mapper.writeValueAsString(response); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid request"}}""")); - } - - // Initialization Tests - - @Test - void testInitializeRequest() throws Exception { - McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() - .roots(true) - .sampling() - .build(); - - McpSchema.Implementation clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); - - McpSchema.InitializeRequest request = new McpSchema.InitializeRequest("2024-11-05", capabilities, clientInfo); - - String value = mapper.writeValueAsString(request); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"protocolVersion":"2024-11-05","capabilities":{"roots":{"listChanged":true},"sampling":{}},"clientInfo":{"name":"test-client","version":"1.0.0"}}""")); - } - - @Test - void testInitializeResult() throws Exception { - McpSchema.ServerCapabilities capabilities = McpSchema.ServerCapabilities.builder() - .logging() - .prompts(true) - .resources(true, true) - .tools(true) - .build(); - - McpSchema.Implementation serverInfo = new McpSchema.Implementation("test-server", "1.0.0"); - - McpSchema.InitializeResult result = new McpSchema.InitializeResult("2024-11-05", capabilities, serverInfo, - "Server initialized successfully"); - - String value = mapper.writeValueAsString(result); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"protocolVersion":"2024-11-05","capabilities":{"logging":{},"prompts":{"listChanged":true},"resources":{"subscribe":true,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"test-server","version":"1.0.0"},"instructions":"Server initialized successfully"}""")); - } - - // Resource Tests - - @Test - void testResource() throws Exception { - McpSchema.Annotations annotations = new McpSchema.Annotations( - Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); - - McpSchema.Resource resource = new McpSchema.Resource("resource://test", "Test Resource", "A test resource", - "text/plain", annotations); - - String value = mapper.writeValueAsString(resource); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"uri":"resource://test","name":"Test Resource","description":"A test resource","mimeType":"text/plain","annotations":{"audience":["user","assistant"],"priority":0.8}}""")); - } - - @Test - void testResourceTemplate() throws Exception { - McpSchema.Annotations annotations = new McpSchema.Annotations(Arrays.asList(McpSchema.Role.USER), 0.5); - - McpSchema.ResourceTemplate template = new McpSchema.ResourceTemplate("resource://{param}/test", "Test Template", - "A test resource template", "text/plain", annotations); - - String value = mapper.writeValueAsString(template); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"uriTemplate":"resource://{param}/test","name":"Test Template","description":"A test resource template","mimeType":"text/plain","annotations":{"audience":["user"],"priority":0.5}}""")); - } - - @Test - void testListResourcesResult() throws Exception { - McpSchema.Resource resource1 = new McpSchema.Resource("resource://test1", "Test Resource 1", - "First test resource", "text/plain", null); - - McpSchema.Resource resource2 = new McpSchema.Resource("resource://test2", "Test Resource 2", - "Second test resource", "application/json", null); - - McpSchema.ListResourcesResult result = new McpSchema.ListResourcesResult(Arrays.asList(resource1, resource2), - "next-cursor"); - - String value = mapper.writeValueAsString(result); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"resources":[{"uri":"resource://test1","name":"Test Resource 1","description":"First test resource","mimeType":"text/plain"},{"uri":"resource://test2","name":"Test Resource 2","description":"Second test resource","mimeType":"application/json"}],"nextCursor":"next-cursor"}""")); - } - - @Test - void testListResourceTemplatesResult() throws Exception { - McpSchema.ResourceTemplate template1 = new McpSchema.ResourceTemplate("resource://{param}/test1", - "Test Template 1", "First test template", "text/plain", null); - - McpSchema.ResourceTemplate template2 = new McpSchema.ResourceTemplate("resource://{param}/test2", - "Test Template 2", "Second test template", "application/json", null); - - McpSchema.ListResourceTemplatesResult result = new McpSchema.ListResourceTemplatesResult( - Arrays.asList(template1, template2), "next-cursor"); - - String value = mapper.writeValueAsString(result); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"resourceTemplates":[{"uriTemplate":"resource://{param}/test1","name":"Test Template 1","description":"First test template","mimeType":"text/plain"},{"uriTemplate":"resource://{param}/test2","name":"Test Template 2","description":"Second test template","mimeType":"application/json"}],"nextCursor":"next-cursor"}""")); - } - - @Test - void testReadResourceRequest() throws Exception { - McpSchema.ReadResourceRequest request = new McpSchema.ReadResourceRequest("resource://test"); - - String value = mapper.writeValueAsString(request); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"uri":"resource://test"}""")); - } - - @Test - void testReadResourceResult() throws Exception { - McpSchema.TextResourceContents contents1 = new McpSchema.TextResourceContents("resource://test1", "text/plain", - "Sample text content"); - - McpSchema.BlobResourceContents contents2 = new McpSchema.BlobResourceContents("resource://test2", - "application/octet-stream", "base64encodedblob"); - - McpSchema.ReadResourceResult result = new McpSchema.ReadResourceResult(Arrays.asList(contents1, contents2)); - - String value = mapper.writeValueAsString(result); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"contents":[{"uri":"resource://test1","mimeType":"text/plain","text":"Sample text content"},{"uri":"resource://test2","mimeType":"application/octet-stream","blob":"base64encodedblob"}]}""")); - } - - // Prompt Tests - - @Test - void testPrompt() throws Exception { - McpSchema.PromptArgument arg1 = new McpSchema.PromptArgument("arg1", "First argument", true); - - McpSchema.PromptArgument arg2 = new McpSchema.PromptArgument("arg2", "Second argument", false); - - McpSchema.Prompt prompt = new McpSchema.Prompt("test-prompt", "A test prompt", Arrays.asList(arg1, arg2)); - - String value = mapper.writeValueAsString(prompt); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"name":"test-prompt","description":"A test prompt","arguments":[{"name":"arg1","description":"First argument","required":true},{"name":"arg2","description":"Second argument","required":false}]}""")); - } - - @Test - void testPromptMessage() throws Exception { - McpSchema.TextContent content = new McpSchema.TextContent("Hello, world!"); - - McpSchema.PromptMessage message = new McpSchema.PromptMessage(McpSchema.Role.USER, content); - - String value = mapper.writeValueAsString(message); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"role":"user","content":{"type":"text","text":"Hello, world!"}}""")); - } - - @Test - void testListPromptsResult() throws Exception { - McpSchema.PromptArgument arg = new McpSchema.PromptArgument("arg", "An argument", true); - - McpSchema.Prompt prompt1 = new McpSchema.Prompt("prompt1", "First prompt", Collections.singletonList(arg)); - - McpSchema.Prompt prompt2 = new McpSchema.Prompt("prompt2", "Second prompt", Collections.emptyList()); - - McpSchema.ListPromptsResult result = new McpSchema.ListPromptsResult(Arrays.asList(prompt1, prompt2), - "next-cursor"); - - String value = mapper.writeValueAsString(result); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"prompts":[{"name":"prompt1","description":"First prompt","arguments":[{"name":"arg","description":"An argument","required":true}]},{"name":"prompt2","description":"Second prompt","arguments":[]}],"nextCursor":"next-cursor"}""")); - } - - @Test - void testGetPromptRequest() throws Exception { - Map arguments = new HashMap<>(); - arguments.put("arg1", "value1"); - arguments.put("arg2", 42); - - McpSchema.GetPromptRequest request = new McpSchema.GetPromptRequest("test-prompt", arguments); - - assertThat(mapper.readValue(""" - {"name":"test-prompt","arguments":{"arg1":"value1","arg2":42}}""", McpSchema.GetPromptRequest.class)) - .isEqualTo(request); - } - - @Test - void testGetPromptResult() throws Exception { - McpSchema.TextContent content1 = new McpSchema.TextContent("System message"); - McpSchema.TextContent content2 = new McpSchema.TextContent("User message"); - - McpSchema.PromptMessage message1 = new McpSchema.PromptMessage(McpSchema.Role.ASSISTANT, content1); - - McpSchema.PromptMessage message2 = new McpSchema.PromptMessage(McpSchema.Role.USER, content2); - - McpSchema.GetPromptResult result = new McpSchema.GetPromptResult("A test prompt result", - Arrays.asList(message1, message2)); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"description":"A test prompt result","messages":[{"role":"assistant","content":{"type":"text","text":"System message"}},{"role":"user","content":{"type":"text","text":"User message"}}]}""")); - } - - // Tool Tests - - @Test - void testTool() throws Exception { - String schemaJson = """ - { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "value": { - "type": "number" - } - }, - "required": ["name"] - } - """; - - McpSchema.Tool tool = new McpSchema.Tool("test-tool", "A test tool", schemaJson); - - String value = mapper.writeValueAsString(tool); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"name":"test-tool","description":"A test tool","inputSchema":{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"number"}},"required":["name"]}}""")); - } - - @Test - void testCallToolRequest() throws Exception { - Map arguments = new HashMap<>(); - arguments.put("name", "test"); - arguments.put("value", 42); - - McpSchema.CallToolRequest request = new McpSchema.CallToolRequest("test-tool", arguments); - - String value = mapper.writeValueAsString(request); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"name":"test-tool","arguments":{"name":"test","value":42}}""")); - } - - @Test - void testCallToolResult() throws Exception { - McpSchema.TextContent content = new McpSchema.TextContent("Tool execution result"); - - McpSchema.CallToolResult result = new McpSchema.CallToolResult(Collections.singletonList(content), false); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"content":[{"type":"text","text":"Tool execution result"}],"isError":false}""")); - } - - // Sampling Tests - - @Test - void testCreateMessageRequest() throws Exception { - McpSchema.TextContent content = new McpSchema.TextContent("User message"); - - McpSchema.SamplingMessage message = new McpSchema.SamplingMessage(McpSchema.Role.USER, content); - - McpSchema.ModelHint hint = new McpSchema.ModelHint("gpt-4"); - - McpSchema.ModelPreferences preferences = new McpSchema.ModelPreferences(Collections.singletonList(hint), 0.3, - 0.7, 0.9); - - Map metadata = new HashMap<>(); - metadata.put("session", "test-session"); - - McpSchema.CreateMessageRequest request = new McpSchema.CreateMessageRequest(Collections.singletonList(message), - preferences, "You are a helpful assistant", - McpSchema.CreateMessageRequest.ContextInclusionStrategy.THIS_SERVER, 0.7, 1000, - Arrays.asList("STOP", "END"), metadata); - - String value = mapper.writeValueAsString(request); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"messages":[{"role":"user","content":{"type":"text","text":"User message"}}],"modelPreferences":{"hints":[{"name":"gpt-4"}],"costPriority":0.3,"speedPriority":0.7,"intelligencePriority":0.9},"systemPrompt":"You are a helpful assistant","includeContext":"this_server","temperature":0.7,"maxTokens":1000,"stopSequences":["STOP","END"],"metadata":{"session":"test-session"}}""")); - } - - @Test - void testCreateMessageResult() throws Exception { - McpSchema.TextContent content = new McpSchema.TextContent("Assistant response"); - - McpSchema.CreateMessageResult result = new McpSchema.CreateMessageResult(McpSchema.Role.ASSISTANT, content, - "gpt-4", McpSchema.CreateMessageResult.StopReason.END_TURN); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"end_turn"}""")); - } - - // Roots Tests - - @Test - void testRoot() throws Exception { - McpSchema.Root root = new McpSchema.Root("file:///path/to/root", "Test Root"); - - String value = mapper.writeValueAsString(root); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"uri":"file:///path/to/root","name":"Test Root"}""")); - } - - @Test - void testListRootsResult() throws Exception { - McpSchema.Root root1 = new McpSchema.Root("file:///path/to/root1", "First Root"); - - McpSchema.Root root2 = new McpSchema.Root("file:///path/to/root2", "Second Root"); - - McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(Arrays.asList(root1, root2)); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"roots":[{"uri":"file:///path/to/root1","name":"First Root"},{"uri":"file:///path/to/root2","name":"Second Root"}]}""")); - - } - -} diff --git a/migration-0.8.0.md b/migration-0.8.0.md new file mode 100644 index 000000000..3ba29a10b --- /dev/null +++ b/migration-0.8.0.md @@ -0,0 +1,328 @@ +# MCP Java SDK Migration Guide: 0.7.0 to 0.8.0 + +This document outlines the breaking changes and provides guidance on how to migrate your code from version 0.7.0 to 0.8.0. + +The 0.8.0 refactoring introduces a session-based architecture for server-side MCP implementations. +It improves the SDK's ability to handle multiple concurrent client connections and provides an API better aligned with the MCP specification. +The main changes include: + +1. Introduction of a session-based architecture +2. New transport provider abstraction +3. Exchange objects for client interaction +4. Renamed and reorganized interfaces +5. Updated handler signatures + +## Breaking Changes + +### 1. Interface Renaming + +Several interfaces have been renamed to better reflect their roles: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `ClientMcpTransport` | `McpClientTransport` | +| `ServerMcpTransport` | `McpServerTransport` | +| `DefaultMcpSession` | `McpClientSession`, `McpServerSession` | + +### 2. New Server Transport Architecture + +The most significant change is the introduction of the `McpServerTransportProvider` interface, which replaces direct usage of `ServerMcpTransport` when creating servers. This new pattern separates the concerns of: + +1. **Transport Provider**: Manages connections with clients and creates individual transports for each connection +2. **Server Transport**: Handles communication with a specific client connection + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `ServerMcpTransport` | `McpServerTransportProvider` + `McpServerTransport` | +| Direct transport usage | Session-based transport usage | + +#### Before (0.7.0): + +```java +// Create a transport +ServerMcpTransport transport = new WebFluxSseServerTransport(objectMapper, "/mcp/message"); + +// Create a server with the transport +McpServer.sync(transport) + .serverInfo("my-server", "1.0.0") + .build(); +``` + +#### After (0.8.0): + +```java +// Create a transport provider +McpServerTransportProvider transportProvider = new WebFluxSseServerTransportProvider(objectMapper, "/mcp/message"); + +// Create a server with the transport provider +McpServer.sync(transportProvider) + .serverInfo("my-server", "1.0.0") + .build(); +``` + +### 3. Handler Method Signature Changes + +Tool, resource, and prompt handlers now receive an additional `exchange` parameter that provides access to client capabilities and methods to interact with the client: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `(args) -> result` | `(exchange, args) -> result` | + +The exchange objects (`McpAsyncServerExchange` and `McpSyncServerExchange`) provide context for the current session and access to session-specific operations. + +#### Before (0.7.0): + +```java +// Tool handler +.tool(calculatorTool, args -> new CallToolResult("Result: " + calculate(args))) + +// Resource handler +.resource(fileResource, req -> new ReadResourceResult(readFile(req))) + +// Prompt handler +.prompt(analysisPrompt, req -> new GetPromptResult("Analysis prompt")) +``` + +#### After (0.8.0): + +```java +// Tool handler +.tool(calculatorTool, (exchange, args) -> new CallToolResult("Result: " + calculate(args))) + +// Resource handler +.resource(fileResource, (exchange, req) -> new ReadResourceResult(readFile(req))) + +// Prompt handler +.prompt(analysisPrompt, (exchange, req) -> new GetPromptResult("Analysis prompt")) +``` + +### 4. Registration vs. Specification + +The naming convention for handlers has changed from "Registration" to "Specification": + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `AsyncToolRegistration` | `AsyncToolSpecification` | +| `SyncToolRegistration` | `SyncToolSpecification` | +| `AsyncResourceRegistration` | `AsyncResourceSpecification` | +| `SyncResourceRegistration` | `SyncResourceSpecification` | +| `AsyncPromptRegistration` | `AsyncPromptSpecification` | +| `SyncPromptRegistration` | `SyncPromptSpecification` | + +### 5. Roots Change Handler Updates + +The roots change handlers now receive an exchange parameter: + +#### Before (0.7.0): + +```java +.rootsChangeConsumers(List.of( + roots -> { + // Process roots + } +)) +``` + +#### After (0.8.0): + +```java +.rootsChangeHandlers(List.of( + (exchange, roots) -> { + // Process roots with access to exchange + } +)) +``` + +### 6. Server Creation Method Changes + +The `McpServer` factory methods now accept `McpServerTransportProvider` instead of `ServerMcpTransport`: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `McpServer.async(ServerMcpTransport)` | `McpServer.async(McpServerTransportProvider)` | +| `McpServer.sync(ServerMcpTransport)` | `McpServer.sync(McpServerTransportProvider)` | + +The method names for creating servers have been updated: + +Root change handlers now receive an exchange object: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `rootsChangeConsumers(List>>)` | `rootsChangeHandlers(List>>)` | +| `rootsChangeConsumer(Consumer>)` | `rootsChangeHandler(BiConsumer>)` | + +### 7. Direct Server Methods Moving to Exchange + +Several methods that were previously available directly on the server are now accessed through the exchange object: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `server.listRoots()` | `exchange.listRoots()` | +| `server.createMessage()` | `exchange.createMessage()` | +| `server.getClientCapabilities()` | `exchange.getClientCapabilities()` | +| `server.getClientInfo()` | `exchange.getClientInfo()` | + +The direct methods are deprecated and will be removed in 0.9.0: + +- `McpSyncServer.listRoots()` +- `McpSyncServer.getClientCapabilities()` +- `McpSyncServer.getClientInfo()` +- `McpSyncServer.createMessage()` +- `McpAsyncServer.listRoots()` +- `McpAsyncServer.getClientCapabilities()` +- `McpAsyncServer.getClientInfo()` +- `McpAsyncServer.createMessage()` + +## Deprecation Notices + +The following components are deprecated in 0.8.0 and will be removed in 0.9.0: + +- `ClientMcpTransport` interface (use `McpClientTransport` instead) +- `ServerMcpTransport` interface (use `McpServerTransport` instead) +- `DefaultMcpSession` class (use `McpClientSession` instead) +- `WebFluxSseServerTransport` class (use `WebFluxSseServerTransportProvider` instead) +- `WebMvcSseServerTransport` class (use `WebMvcSseServerTransportProvider` instead) +- `StdioServerTransport` class (use `StdioServerTransportProvider` instead) +- All `*Registration` classes (use corresponding `*Specification` classes instead) +- Direct server methods for client interaction (use exchange object instead) + +## Migration Examples + +### Example 1: Creating a Server + +#### Before (0.7.0): + +```java +// Create a transport +ServerMcpTransport transport = new WebFluxSseServerTransport(objectMapper, "/mcp/message"); + +// Create a server with the transport +var server = McpServer.sync(transport) + .serverInfo("my-server", "1.0.0") + .tool(calculatorTool, args -> new CallToolResult("Result: " + calculate(args))) + .rootsChangeConsumers(List.of( + roots -> System.out.println("Roots changed: " + roots) + )) + .build(); + +// Get client capabilities directly from server +ClientCapabilities capabilities = server.getClientCapabilities(); +``` + +#### After (0.8.0): + +```java +// Create a transport provider +McpServerTransportProvider transportProvider = new WebFluxSseServerTransportProvider(objectMapper, "/mcp/message"); + +// Create a server with the transport provider +var server = McpServer.sync(transportProvider) + .serverInfo("my-server", "1.0.0") + .tool(calculatorTool, (exchange, args) -> { + // Get client capabilities from exchange + ClientCapabilities capabilities = exchange.getClientCapabilities(); + return new CallToolResult("Result: " + calculate(args)); + }) + .rootsChangeHandlers(List.of( + (exchange, roots) -> System.out.println("Roots changed: " + roots) + )) + .build(); +``` + +### Example 2: Implementing a Tool with Client Interaction + +#### Before (0.7.0): + +```java +McpServerFeatures.SyncToolRegistration tool = new McpServerFeatures.SyncToolRegistration( + new Tool("weather", "Get weather information", schema), + args -> { + String location = (String) args.get("location"); + // Cannot interact with client from here + return new CallToolResult("Weather for " + location + ": Sunny"); + } +); + +var server = McpServer.sync(transport) + .tools(tool) + .build(); + +// Separate call to create a message +CreateMessageResult result = server.createMessage(new CreateMessageRequest(...)); +``` + +#### After (0.8.0): + +```java +McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new Tool("weather", "Get weather information", schema), + (exchange, args) -> { + String location = (String) args.get("location"); + + // Can interact with client directly from the tool handler + CreateMessageResult result = exchange.createMessage(new CreateMessageRequest(...)); + + return new CallToolResult("Weather for " + location + ": " + result.content()); + } +); + +var server = McpServer.sync(transportProvider) + .tools(tool) + .build(); +``` + +### Example 3: Converting Existing Registration Classes + +If you have custom implementations of the registration classes, you can convert them to the new specification classes: + +#### Before (0.7.0): + +```java +McpServerFeatures.AsyncToolRegistration toolReg = new McpServerFeatures.AsyncToolRegistration( + tool, + args -> Mono.just(new CallToolResult("Result")) +); + +McpServerFeatures.AsyncResourceRegistration resourceReg = new McpServerFeatures.AsyncResourceRegistration( + resource, + req -> Mono.just(new ReadResourceResult(List.of())) +); +``` + +#### After (0.8.0): + +```java +// Option 1: Create new specification directly +McpServerFeatures.AsyncToolSpecification toolSpec = new McpServerFeatures.AsyncToolSpecification( + tool, + (exchange, args) -> Mono.just(new CallToolResult("Result")) +); + +// Option 2: Convert from existing registration (during transition) +McpServerFeatures.AsyncToolRegistration oldToolReg = /* existing registration */; +McpServerFeatures.AsyncToolSpecification toolSpec = oldToolReg.toSpecification(); + +// Similarly for resources +McpServerFeatures.AsyncResourceSpecification resourceSpec = new McpServerFeatures.AsyncResourceSpecification( + resource, + (exchange, req) -> Mono.just(new ReadResourceResult(List.of())) +); +``` + +## Architecture Changes + +### Session-Based Architecture + +In 0.8.0, the MCP Java SDK introduces a session-based architecture where each client connection has its own session. This allows for better isolation between clients and more efficient resource management. + +The `McpServerTransportProvider` is responsible for creating `McpServerTransport` instances for each session, and the `McpServerSession` manages the communication with a specific client. + +### Exchange Objects + +The new exchange objects (`McpAsyncServerExchange` and `McpSyncServerExchange`) provide access to client-specific information and methods. They are passed to handler functions as the first parameter, allowing handlers to interact with the specific client that made the request. + +## Conclusion + +The changes in version 0.8.0 represent a significant architectural improvement to the MCP Java SDK. While they require some code changes, the new design provides a more flexible and maintainable foundation for building MCP applications. + +For assistance with migration or to report issues, please open an issue on the GitHub repository. diff --git a/pom.xml b/pom.xml index 893e5eb9b..f8bc3a9c2 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.8.0-SNAPSHOT + 0.18.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk @@ -57,22 +57,25 @@ 17 17 17 + - 3.26.3 + 3.27.6 5.10.2 - 5.11.0 - 1.20.4 + 5.20.0 + 1.21.4 + 1.17.8 + 1.21.0 2.0.16 1.5.15 - 2.17.0 + 2.19.2 6.2.1 3.11.0 3.1.2 3.5.2 - 3.5.0 + 3.11.2 3.3.0 0.8.10 1.5.0 @@ -93,12 +96,16 @@ 4.2.0 7.1.0 4.1.0 + 2.0.0 mcp-bom mcp + mcp-core + mcp-json-jackson2 + mcp-json mcp-spring/mcp-spring-webflux mcp-spring/mcp-spring-webmvc mcp-test @@ -162,13 +169,23 @@ + + org.apache.maven.plugins + maven-dependency-plugin + + + + properties + + + + org.apache.maven.plugins maven-surefire-plugin ${maven-surefire-plugin.version} - ${surefireArgLine} - + ${surefireArgLine} -javaagent:${org.mockito:mockito-core:jar} false false @@ -262,6 +279,7 @@ ${maven-javadoc-plugin.version} false + true false none @@ -301,7 +319,7 @@ true central - + true @@ -356,4 +374,4 @@ - \ No newline at end of file +