From 3956a97e7c6d35b38f805e44c449c0d465fdf560 Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Mon, 12 Jan 2026 13:11:09 +0100 Subject: [PATCH 1/2] WIP --- go.mod | 14 +- go.sum | 36 +- internal/ghmcp/http.go | 129 ++++ internal/ghmcp/server.go | 682 ------------------ internal/ghmcp/stdio.go | 169 +++++ .../ghmcp/{server_test.go => stdio_test.go} | 21 +- pkg/github/host.go | 203 ++++++ pkg/github/params.go | 385 ++++++++++ pkg/github/params_test.go | 503 +++++++++++++ pkg/github/server.go | 625 ++++++++-------- pkg/github/server_test.go | 496 ------------- pkg/http/server.go | 166 +++++ 12 files changed, 1893 insertions(+), 1536 deletions(-) create mode 100644 internal/ghmcp/http.go delete mode 100644 internal/ghmcp/server.go create mode 100644 internal/ghmcp/stdio.go rename internal/ghmcp/{server_test.go => stdio_test.go} (87%) create mode 100644 pkg/github/host.go create mode 100644 pkg/github/params.go create mode 100644 pkg/github/params_test.go create mode 100644 pkg/http/server.go diff --git a/go.mod b/go.mod index 5322b47ec..1b1af830a 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,10 @@ require ( require ( github.com/aymerick/douceur v0.2.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/swag v0.21.1 // indirect github.com/gorilla/css v1.0.1 // indirect @@ -22,6 +26,10 @@ require ( github.com/mailru/easyjson v0.7.7 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/otel v1.39.0 // indirect + go.opentelemetry.io/otel/metric v1.39.0 // indirect + go.opentelemetry.io/otel/trace v1.39.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/net v0.38.0 // indirect @@ -31,13 +39,13 @@ require ( require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-chi/chi v1.5.5 github.com/go-viper/mapstructure/v2 v2.4.0 github.com/google/go-querystring v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7 github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 @@ -47,9 +55,9 @@ require ( github.com/spf13/pflag v1.0.10 github.com/subosito/gotenv v1.6.0 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 golang.org/x/oauth2 v0.30.0 // indirect - golang.org/x/sys v0.31.0 // indirect + golang.org/x/sys v0.39.0 // indirect golang.org/x/text v0.28.0 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 25cbf7fa9..10ef4e6d4 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,26 @@ github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-chi/chi v1.5.5 h1:vOB/HbEMt9QqBqErz07QehcOKHaWFtuj87tTDVz2qXE= +github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= @@ -28,6 +39,8 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -37,7 +50,6 @@ github.com/josephburnett/jd v1.9.2/go.mod h1:bImDr8QXpxMb3SD+w1cDRHp97xP6UwI88xU github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -61,8 +73,8 @@ github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= @@ -96,6 +108,20 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M= github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= +go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= @@ -104,8 +130,8 @@ golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= diff --git a/internal/ghmcp/http.go b/internal/ghmcp/http.go new file mode 100644 index 000000000..a46a502b0 --- /dev/null +++ b/internal/ghmcp/http.go @@ -0,0 +1,129 @@ +package ghmcp + +import ( + "context" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/translations" +) + +type HTTPServerConfig struct { + // Version of the server + Version string + + // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) + Host string + + // GitHub Token to authenticate with the GitHub API + Token string + + // EnabledToolsets is a list of toolsets to enable + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#tool-configuration + EnabledToolsets []string + + // EnabledTools is a list of specific tools to enable (additive to toolsets) + // When specified, these tools are registered in addition to any specified toolset tools + EnabledTools []string + + // EnabledFeatures is a list of feature flags that are enabled + // Items with FeatureFlagEnable matching an entry in this list will be available + EnabledFeatures []string + + // Whether to enable dynamic toolsets + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery + DynamicToolsets bool + + // ReadOnly indicates if we should only register read-only tools + ReadOnly bool + + // ExportTranslations indicates if we should export translations + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#i18n--overriding-descriptions + ExportTranslations bool + + // EnableCommandLogging indicates if we should log commands + EnableCommandLogging bool + + // Path to the log file if not stderr + LogFilePath string + + // Content window size + ContentWindowSize int + + // LockdownMode indicates if we should enable lockdown mode + LockdownMode bool + + // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. + RepoAccessCacheTTL *time.Duration + + http.Handler +} + +func RunHTTPServer(cfg HTTPServerConfig) error { + // Create app context + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + t, dumpTranslations := translations.TranslationHelper() + + var slogHandler slog.Handler + var logOutput io.Writer + if cfg.LogFilePath != "" { + file, err := os.OpenFile(cfg.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + if err != nil { + return fmt.Errorf("failed to open log file: %w", err) + } + logOutput = file + slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelDebug}) + } else { + logOutput = os.Stderr + slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelInfo}) + } + logger := slog.New(slogHandler) + logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "dynamicToolsets", cfg.DynamicToolsets, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) + + // Fetch token scopes for scope-based tool filtering (PAT tokens only) + // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. + // Fine-grained PATs and other token types don't support this, so we skip filtering. + var tokenScopes []string + if strings.HasPrefix(cfg.Token, "ghp_") { + fetchedScopes, err := github.FetchTokenScopesForHost(ctx, cfg.Token, cfg.Host) + if err != nil { + logger.Warn("failed to fetch token scopes, continuing without scope filtering", "error", err) + } else { + tokenScopes = fetchedScopes + logger.Info("token scopes fetched for filtering", "scopes", tokenScopes) + } + } else { + logger.Debug("skipping scope filtering for non-PAT token") + } + + ghServer, err := github.NewMCPServer(github.MCPServerConfig{ + Version: cfg.Version, + Host: cfg.Host, + Token: cfg.Token, + EnabledToolsets: cfg.EnabledToolsets, + EnabledTools: cfg.EnabledTools, + EnabledFeatures: cfg.EnabledFeatures, + DynamicToolsets: cfg.DynamicToolsets, + ReadOnly: cfg.ReadOnly, + Translator: t, + ContentWindowSize: cfg.ContentWindowSize, + LockdownMode: cfg.LockdownMode, + Logger: logger, + RepoAccessTTL: cfg.RepoAccessCacheTTL, + TokenScopes: tokenScopes, + }) + if err != nil { + return fmt.Errorf("failed to create MCP server: %w", err) + } + +} diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go deleted file mode 100644 index 165886606..000000000 --- a/internal/ghmcp/server.go +++ /dev/null @@ -1,682 +0,0 @@ -package ghmcp - -import ( - "context" - "fmt" - "io" - "log/slog" - "net/http" - "net/url" - "os" - "os/signal" - "strings" - "syscall" - "time" - - "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/github" - "github.com/github/github-mcp-server/pkg/inventory" - "github.com/github/github-mcp-server/pkg/lockdown" - mcplog "github.com/github/github-mcp-server/pkg/log" - "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/scopes" - "github.com/github/github-mcp-server/pkg/translations" - gogithub "github.com/google/go-github/v79/github" - "github.com/modelcontextprotocol/go-sdk/mcp" - "github.com/shurcooL/githubv4" -) - -type MCPServerConfig struct { - // Version of the server - Version string - - // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) - Host string - - // GitHub Token to authenticate with the GitHub API - Token string - - // EnabledToolsets is a list of toolsets to enable - // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#tool-configuration - EnabledToolsets []string - - // EnabledTools is a list of specific tools to enable (additive to toolsets) - // When specified, these tools are registered in addition to any specified toolset tools - EnabledTools []string - - // EnabledFeatures is a list of feature flags that are enabled - // Items with FeatureFlagEnable matching an entry in this list will be available - EnabledFeatures []string - - // Whether to enable dynamic toolsets - // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery - DynamicToolsets bool - - // ReadOnly indicates if we should only offer read-only tools - ReadOnly bool - - // Translator provides translated text for the server tooling - Translator translations.TranslationHelperFunc - - // Content window size - ContentWindowSize int - - // LockdownMode indicates if we should enable lockdown mode - LockdownMode bool - - // Logger is used for logging within the server - Logger *slog.Logger - // RepoAccessTTL overrides the default TTL for repository access cache entries. - RepoAccessTTL *time.Duration - - // TokenScopes contains the OAuth scopes available to the token. - // When non-nil, tools requiring scopes not in this list will be hidden. - // This is used for PAT scope filtering where we can't issue scope challenges. - TokenScopes []string -} - -// githubClients holds all the GitHub API clients created for a server instance. -type githubClients struct { - rest *gogithub.Client - gql *githubv4.Client - gqlHTTP *http.Client // retained for middleware to modify transport - raw *raw.Client - repoAccess *lockdown.RepoAccessCache -} - -// createGitHubClients creates all the GitHub API clients needed by the server. -func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, error) { - // Construct REST client - restClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token) - restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version) - restClient.BaseURL = apiHost.baseRESTURL - restClient.UploadURL = apiHost.uploadURL - - // Construct GraphQL client - // We use NewEnterpriseClient unconditionally since we already parsed the API host - gqlHTTPClient := &http.Client{ - Transport: &bearerAuthTransport{ - transport: http.DefaultTransport, - token: cfg.Token, - }, - } - gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) - - // Create raw content client (shares REST client's HTTP transport) - rawClient := raw.NewClient(restClient, apiHost.rawURL) - - // Set up repo access cache for lockdown mode - var repoAccessCache *lockdown.RepoAccessCache - if cfg.LockdownMode { - opts := []lockdown.RepoAccessOption{ - lockdown.WithLogger(cfg.Logger.With("component", "lockdown")), - } - if cfg.RepoAccessTTL != nil { - opts = append(opts, lockdown.WithTTL(*cfg.RepoAccessTTL)) - } - repoAccessCache = lockdown.GetInstance(gqlClient, opts...) - } - - return &githubClients{ - rest: restClient, - gql: gqlClient, - gqlHTTP: gqlHTTPClient, - raw: rawClient, - repoAccess: repoAccessCache, - }, nil -} - -// resolveEnabledToolsets determines which toolsets should be enabled based on config. -// Returns nil for "use defaults", empty slice for "none", or explicit list. -func resolveEnabledToolsets(cfg MCPServerConfig) []string { - enabledToolsets := cfg.EnabledToolsets - - // In dynamic mode, remove "all" and "default" since users enable toolsets on demand - if cfg.DynamicToolsets && enabledToolsets != nil { - enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataAll.ID)) - enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataDefault.ID)) - } - - if enabledToolsets != nil { - return enabledToolsets - } - if cfg.DynamicToolsets { - // Dynamic mode with no toolsets specified: start empty so users enable on demand - return []string{} - } - if len(cfg.EnabledTools) > 0 { - // When specific tools are requested but no toolsets, don't use default toolsets - // This matches the original behavior: --tools=X alone registers only X - return []string{} - } - // nil means "use defaults" in WithToolsets - return nil -} - -func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { - apiHost, err := parseAPIHost(cfg.Host) - if err != nil { - return nil, fmt.Errorf("failed to parse API host: %w", err) - } - - clients, err := createGitHubClients(cfg, apiHost) - if err != nil { - return nil, fmt.Errorf("failed to create GitHub clients: %w", err) - } - - enabledToolsets := resolveEnabledToolsets(cfg) - - // For instruction generation, we need actual toolset names (not nil). - // nil means "use defaults" in inventory, so expand it for instructions. - instructionToolsets := enabledToolsets - if instructionToolsets == nil { - instructionToolsets = github.GetDefaultToolsetIDs() - } - - // Create the MCP server - serverOpts := &mcp.ServerOptions{ - Instructions: github.GenerateInstructions(instructionToolsets), - Logger: cfg.Logger, - CompletionHandler: github.CompletionsHandler(func(_ context.Context) (*gogithub.Client, error) { - return clients.rest, nil - }), - } - - // In dynamic mode, explicitly advertise capabilities since tools/resources/prompts - // may be enabled at runtime even if none are registered initially. - if cfg.DynamicToolsets { - serverOpts.Capabilities = &mcp.ServerCapabilities{ - Tools: &mcp.ToolCapabilities{}, - Resources: &mcp.ResourceCapabilities{}, - Prompts: &mcp.PromptCapabilities{}, - } - } - - ghServer := github.NewServer(cfg.Version, serverOpts) - - // Add middlewares - ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) - ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, clients.rest, clients.gqlHTTP)) - - // Create dependencies for tool handlers - deps := github.NewBaseDeps( - clients.rest, - clients.gql, - clients.raw, - clients.repoAccess, - cfg.Translator, - github.FeatureFlags{LockdownMode: cfg.LockdownMode}, - cfg.ContentWindowSize, - ) - - // Inject dependencies into context for all tool handlers - ghServer.AddReceivingMiddleware(func(next mcp.MethodHandler) mcp.MethodHandler { - return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - return next(github.ContextWithDeps(ctx, deps), method, req) - } - }) - - // Build and register the tool/resource/prompt inventory - inventoryBuilder := github.NewInventory(cfg.Translator). - WithDeprecatedAliases(github.DeprecatedToolAliases). - WithReadOnly(cfg.ReadOnly). - WithToolsets(enabledToolsets). - WithTools(github.CleanTools(cfg.EnabledTools)). - WithFeatureChecker(createFeatureChecker(cfg.EnabledFeatures)) - - // Apply token scope filtering if scopes are known (for PAT filtering) - if cfg.TokenScopes != nil { - inventoryBuilder = inventoryBuilder.WithFilter(github.CreateToolScopeFilter(cfg.TokenScopes)) - } - - inventory := inventoryBuilder.Build() - - if unrecognized := inventory.UnrecognizedToolsets(); len(unrecognized) > 0 { - fmt.Fprintf(os.Stderr, "Warning: unrecognized toolsets ignored: %s\n", strings.Join(unrecognized, ", ")) - } - - // Register GitHub tools/resources/prompts from the inventory. - // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets - // is empty - users enable toolsets at runtime via the dynamic tools below (but can - // enable toolsets or tools explicitly that do need registration). - inventory.RegisterAll(context.Background(), ghServer, deps) - - // Register dynamic toolset management tools (enable/disable) - these are separate - // meta-tools that control the inventory, not part of the inventory itself - if cfg.DynamicToolsets { - registerDynamicTools(ghServer, inventory, deps, cfg.Translator) - } - - return ghServer, nil -} - -// registerDynamicTools adds the dynamic toolset enable/disable tools to the server. -func registerDynamicTools(server *mcp.Server, inventory *inventory.Inventory, deps *github.BaseDeps, t translations.TranslationHelperFunc) { - dynamicDeps := github.DynamicToolDependencies{ - Server: server, - Inventory: inventory, - ToolDeps: deps, - T: t, - } - for _, tool := range github.DynamicTools(inventory) { - tool.RegisterFunc(server, dynamicDeps) - } -} - -// createFeatureChecker returns a FeatureFlagChecker that checks if a flag name -// is present in the provided list of enabled features. For the local server, -// this is populated from the --features CLI flag. -func createFeatureChecker(enabledFeatures []string) inventory.FeatureFlagChecker { - // Build a set for O(1) lookup - featureSet := make(map[string]bool, len(enabledFeatures)) - for _, f := range enabledFeatures { - featureSet[f] = true - } - return func(_ context.Context, flagName string) (bool, error) { - return featureSet[flagName], nil - } -} - -type StdioServerConfig struct { - // Version of the server - Version string - - // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) - Host string - - // GitHub Token to authenticate with the GitHub API - Token string - - // EnabledToolsets is a list of toolsets to enable - // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#tool-configuration - EnabledToolsets []string - - // EnabledTools is a list of specific tools to enable (additive to toolsets) - // When specified, these tools are registered in addition to any specified toolset tools - EnabledTools []string - - // EnabledFeatures is a list of feature flags that are enabled - // Items with FeatureFlagEnable matching an entry in this list will be available - EnabledFeatures []string - - // Whether to enable dynamic toolsets - // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery - DynamicToolsets bool - - // ReadOnly indicates if we should only register read-only tools - ReadOnly bool - - // ExportTranslations indicates if we should export translations - // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#i18n--overriding-descriptions - ExportTranslations bool - - // EnableCommandLogging indicates if we should log commands - EnableCommandLogging bool - - // Path to the log file if not stderr - LogFilePath string - - // Content window size - ContentWindowSize int - - // LockdownMode indicates if we should enable lockdown mode - LockdownMode bool - - // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. - RepoAccessCacheTTL *time.Duration -} - -// RunStdioServer is not concurrent safe. -func RunStdioServer(cfg StdioServerConfig) error { - // Create app context - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer stop() - - t, dumpTranslations := translations.TranslationHelper() - - var slogHandler slog.Handler - var logOutput io.Writer - if cfg.LogFilePath != "" { - file, err := os.OpenFile(cfg.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) - if err != nil { - return fmt.Errorf("failed to open log file: %w", err) - } - logOutput = file - slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelDebug}) - } else { - logOutput = os.Stderr - slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelInfo}) - } - logger := slog.New(slogHandler) - logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "dynamicToolsets", cfg.DynamicToolsets, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) - - // Fetch token scopes for scope-based tool filtering (PAT tokens only) - // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. - // Fine-grained PATs and other token types don't support this, so we skip filtering. - var tokenScopes []string - if strings.HasPrefix(cfg.Token, "ghp_") { - fetchedScopes, err := fetchTokenScopesForHost(ctx, cfg.Token, cfg.Host) - if err != nil { - logger.Warn("failed to fetch token scopes, continuing without scope filtering", "error", err) - } else { - tokenScopes = fetchedScopes - logger.Info("token scopes fetched for filtering", "scopes", tokenScopes) - } - } else { - logger.Debug("skipping scope filtering for non-PAT token") - } - - ghServer, err := NewMCPServer(MCPServerConfig{ - Version: cfg.Version, - Host: cfg.Host, - Token: cfg.Token, - EnabledToolsets: cfg.EnabledToolsets, - EnabledTools: cfg.EnabledTools, - EnabledFeatures: cfg.EnabledFeatures, - DynamicToolsets: cfg.DynamicToolsets, - ReadOnly: cfg.ReadOnly, - Translator: t, - ContentWindowSize: cfg.ContentWindowSize, - LockdownMode: cfg.LockdownMode, - Logger: logger, - RepoAccessTTL: cfg.RepoAccessCacheTTL, - TokenScopes: tokenScopes, - }) - if err != nil { - return fmt.Errorf("failed to create MCP server: %w", err) - } - - if cfg.ExportTranslations { - // Once server is initialized, all translations are loaded - dumpTranslations() - } - - // Start listening for messages - errC := make(chan error, 1) - go func() { - var in io.ReadCloser - var out io.WriteCloser - - in = os.Stdin - out = os.Stdout - - if cfg.EnableCommandLogging { - loggedIO := mcplog.NewIOLogger(in, out, logger) - in, out = loggedIO, loggedIO - } - - // enable GitHub errors in the context - ctx := errors.ContextWithGitHubErrors(ctx) - errC <- ghServer.Run(ctx, &mcp.IOTransport{Reader: in, Writer: out}) - }() - - // Output github-mcp-server string - _, _ = fmt.Fprintf(os.Stderr, "GitHub MCP Server running on stdio\n") - - // Wait for shutdown signal - select { - case <-ctx.Done(): - logger.Info("shutting down server", "signal", "context done") - case err := <-errC: - if err != nil { - logger.Error("error running server", "error", err) - return fmt.Errorf("error running server: %w", err) - } - } - - return nil -} - -type apiHost struct { - baseRESTURL *url.URL - graphqlURL *url.URL - uploadURL *url.URL - rawURL *url.URL -} - -func newDotcomHost() (apiHost, error) { - baseRestURL, err := url.Parse("https://api.github.com/") - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse dotcom REST URL: %w", err) - } - - gqlURL, err := url.Parse("https://api.github.com/graphql") - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse dotcom GraphQL URL: %w", err) - } - - uploadURL, err := url.Parse("https://uploads.github.com") - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse dotcom Upload URL: %w", err) - } - - rawURL, err := url.Parse("https://raw.githubusercontent.com/") - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err) - } - - return apiHost{ - baseRESTURL: baseRestURL, - graphqlURL: gqlURL, - uploadURL: uploadURL, - rawURL: rawURL, - }, nil -} - -func newGHECHost(hostname string) (apiHost, error) { - u, err := url.Parse(hostname) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC URL: %w", err) - } - - // Unsecured GHEC would be an error - if u.Scheme == "http" { - return apiHost{}, fmt.Errorf("GHEC URL must be HTTPS") - } - - restURL, err := url.Parse(fmt.Sprintf("https://api.%s/", u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC REST URL: %w", err) - } - - gqlURL, err := url.Parse(fmt.Sprintf("https://api.%s/graphql", u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC GraphQL URL: %w", err) - } - - uploadURL, err := url.Parse(fmt.Sprintf("https://uploads.%s", u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC Upload URL: %w", err) - } - - rawURL, err := url.Parse(fmt.Sprintf("https://raw.%s/", u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err) - } - - return apiHost{ - baseRESTURL: restURL, - graphqlURL: gqlURL, - uploadURL: uploadURL, - rawURL: rawURL, - }, nil -} - -func newGHESHost(hostname string) (apiHost, error) { - u, err := url.Parse(hostname) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES URL: %w", err) - } - - restURL, err := url.Parse(fmt.Sprintf("%s://%s/api/v3/", u.Scheme, u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES REST URL: %w", err) - } - - gqlURL, err := url.Parse(fmt.Sprintf("%s://%s/api/graphql", u.Scheme, u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES GraphQL URL: %w", err) - } - - // Check if subdomain isolation is enabled - // See https://docs.github.com/en/enterprise-server@3.17/admin/configuring-settings/hardening-security-for-your-enterprise/enabling-subdomain-isolation#about-subdomain-isolation - hasSubdomainIsolation := checkSubdomainIsolation(u.Scheme, u.Hostname()) - - var uploadURL *url.URL - if hasSubdomainIsolation { - // With subdomain isolation: https://uploads.hostname/ - uploadURL, err = url.Parse(fmt.Sprintf("%s://uploads.%s/", u.Scheme, u.Hostname())) - } else { - // Without subdomain isolation: https://hostname/api/uploads/ - uploadURL, err = url.Parse(fmt.Sprintf("%s://%s/api/uploads/", u.Scheme, u.Hostname())) - } - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES Upload URL: %w", err) - } - - var rawURL *url.URL - if hasSubdomainIsolation { - // With subdomain isolation: https://raw.hostname/ - rawURL, err = url.Parse(fmt.Sprintf("%s://raw.%s/", u.Scheme, u.Hostname())) - } else { - // Without subdomain isolation: https://hostname/raw/ - rawURL, err = url.Parse(fmt.Sprintf("%s://%s/raw/", u.Scheme, u.Hostname())) - } - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err) - } - - return apiHost{ - baseRESTURL: restURL, - graphqlURL: gqlURL, - uploadURL: uploadURL, - rawURL: rawURL, - }, nil -} - -// checkSubdomainIsolation detects if GitHub Enterprise Server has subdomain isolation enabled -// by attempting to ping the raw./_ping endpoint on the subdomain. The raw subdomain must always exist for subdomain isolation. -func checkSubdomainIsolation(scheme, hostname string) bool { - subdomainURL := fmt.Sprintf("%s://raw.%s/_ping", scheme, hostname) - - client := &http.Client{ - Timeout: 5 * time.Second, - // Don't follow redirects - we just want to check if the endpoint exists - //nolint:revive // parameters are required by http.Client.CheckRedirect signature - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - resp, err := client.Get(subdomainURL) - if err != nil { - return false - } - defer resp.Body.Close() - - return resp.StatusCode == http.StatusOK -} - -// Note that this does not handle ports yet, so development environments are out. -func parseAPIHost(s string) (apiHost, error) { - if s == "" { - return newDotcomHost() - } - - u, err := url.Parse(s) - if err != nil { - return apiHost{}, fmt.Errorf("could not parse host as URL: %s", s) - } - - if u.Scheme == "" { - return apiHost{}, fmt.Errorf("host must have a scheme (http or https): %s", s) - } - - if strings.HasSuffix(u.Hostname(), "github.com") { - return newDotcomHost() - } - - if strings.HasSuffix(u.Hostname(), "ghe.com") { - return newGHECHost(s) - } - - return newGHESHost(s) -} - -type userAgentTransport struct { - transport http.RoundTripper - agent string -} - -func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - req.Header.Set("User-Agent", t.agent) - return t.transport.RoundTrip(req) -} - -type bearerAuthTransport struct { - transport http.RoundTripper - token string -} - -func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - req.Header.Set("Authorization", "Bearer "+t.token) - return t.transport.RoundTrip(req) -} - -func addGitHubAPIErrorToContext(next mcp.MethodHandler) mcp.MethodHandler { - return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) { - // Ensure the context is cleared of any previous errors - // as context isn't propagated through middleware - ctx = errors.ContextWithGitHubErrors(ctx) - return next(ctx, method, req) - } -} - -func addUserAgentsMiddleware(cfg MCPServerConfig, restClient *gogithub.Client, gqlHTTPClient *http.Client) func(next mcp.MethodHandler) mcp.MethodHandler { - return func(next mcp.MethodHandler) mcp.MethodHandler { - return func(ctx context.Context, method string, request mcp.Request) (result mcp.Result, err error) { - if method != "initialize" { - return next(ctx, method, request) - } - - initializeRequest, ok := request.(*mcp.InitializeRequest) - if !ok { - return next(ctx, method, request) - } - - message := initializeRequest - userAgent := fmt.Sprintf( - "github-mcp-server/%s (%s/%s)", - cfg.Version, - message.Params.ClientInfo.Name, - message.Params.ClientInfo.Version, - ) - - restClient.UserAgent = userAgent - - gqlHTTPClient.Transport = &userAgentTransport{ - transport: gqlHTTPClient.Transport, - agent: userAgent, - } - - return next(ctx, method, request) - } - } -} - -// fetchTokenScopesForHost fetches the OAuth scopes for a token from the GitHub API. -// It constructs the appropriate API host URL based on the configured host. -func fetchTokenScopesForHost(ctx context.Context, token, host string) ([]string, error) { - apiHost, err := parseAPIHost(host) - if err != nil { - return nil, fmt.Errorf("failed to parse API host: %w", err) - } - - fetcher := scopes.NewFetcher(scopes.FetcherOptions{ - APIHost: apiHost.baseRESTURL.String(), - }) - - return fetcher.FetchTokenScopes(ctx, token) -} diff --git a/internal/ghmcp/stdio.go b/internal/ghmcp/stdio.go new file mode 100644 index 000000000..0bdcae440 --- /dev/null +++ b/internal/ghmcp/stdio.go @@ -0,0 +1,169 @@ +package ghmcp + +import ( + "context" + "fmt" + "io" + "log/slog" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/github" + mcplog "github.com/github/github-mcp-server/pkg/log" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type StdioServerConfig struct { + // Version of the server + Version string + + // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) + Host string + + // GitHub Token to authenticate with the GitHub API + Token string + + // EnabledToolsets is a list of toolsets to enable + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#tool-configuration + EnabledToolsets []string + + // EnabledTools is a list of specific tools to enable (additive to toolsets) + // When specified, these tools are registered in addition to any specified toolset tools + EnabledTools []string + + // EnabledFeatures is a list of feature flags that are enabled + // Items with FeatureFlagEnable matching an entry in this list will be available + EnabledFeatures []string + + // Whether to enable dynamic toolsets + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery + DynamicToolsets bool + + // ReadOnly indicates if we should only register read-only tools + ReadOnly bool + + // ExportTranslations indicates if we should export translations + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#i18n--overriding-descriptions + ExportTranslations bool + + // EnableCommandLogging indicates if we should log commands + EnableCommandLogging bool + + // Path to the log file if not stderr + LogFilePath string + + // Content window size + ContentWindowSize int + + // LockdownMode indicates if we should enable lockdown mode + LockdownMode bool + + // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. + RepoAccessCacheTTL *time.Duration +} + +// RunStdioServer is not concurrent safe. +func RunStdioServer(cfg StdioServerConfig) error { + // Create app context + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + t, dumpTranslations := translations.TranslationHelper() + + var slogHandler slog.Handler + var logOutput io.Writer + if cfg.LogFilePath != "" { + file, err := os.OpenFile(cfg.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + if err != nil { + return fmt.Errorf("failed to open log file: %w", err) + } + logOutput = file + slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelDebug}) + } else { + logOutput = os.Stderr + slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelInfo}) + } + logger := slog.New(slogHandler) + logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "dynamicToolsets", cfg.DynamicToolsets, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) + + // Fetch token scopes for scope-based tool filtering (PAT tokens only) + // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. + // Fine-grained PATs and other token types don't support this, so we skip filtering. + var tokenScopes []string + if strings.HasPrefix(cfg.Token, "ghp_") { + fetchedScopes, err := github.FetchTokenScopesForHost(ctx, cfg.Token, cfg.Host) + if err != nil { + logger.Warn("failed to fetch token scopes, continuing without scope filtering", "error", err) + } else { + tokenScopes = fetchedScopes + logger.Info("token scopes fetched for filtering", "scopes", tokenScopes) + } + } else { + logger.Debug("skipping scope filtering for non-PAT token") + } + + ghServer, err := github.NewMCPServer(github.MCPServerConfig{ + Version: cfg.Version, + Host: cfg.Host, + Token: cfg.Token, + EnabledToolsets: cfg.EnabledToolsets, + EnabledTools: cfg.EnabledTools, + EnabledFeatures: cfg.EnabledFeatures, + DynamicToolsets: cfg.DynamicToolsets, + ReadOnly: cfg.ReadOnly, + Translator: t, + ContentWindowSize: cfg.ContentWindowSize, + LockdownMode: cfg.LockdownMode, + Logger: logger, + RepoAccessTTL: cfg.RepoAccessCacheTTL, + TokenScopes: tokenScopes, + }) + if err != nil { + return fmt.Errorf("failed to create MCP server: %w", err) + } + + if cfg.ExportTranslations { + // Once server is initialized, all translations are loaded + dumpTranslations() + } + + // Start listening for messages + errC := make(chan error, 1) + go func() { + var in io.ReadCloser + var out io.WriteCloser + + in = os.Stdin + out = os.Stdout + + if cfg.EnableCommandLogging { + loggedIO := mcplog.NewIOLogger(in, out, logger) + in, out = loggedIO, loggedIO + } + + // enable GitHub errors in the context + ctx := errors.ContextWithGitHubErrors(ctx) + errC <- ghServer.Run(ctx, &mcp.IOTransport{Reader: in, Writer: out}) + }() + + // Output github-mcp-server string + _, _ = fmt.Fprintf(os.Stderr, "GitHub MCP Server running on stdio\n") + + // Wait for shutdown signal + select { + case <-ctx.Done(): + logger.Info("shutting down server", "signal", "context done") + case err := <-errC: + if err != nil { + logger.Error("error running server", "error", err) + return fmt.Errorf("error running server: %w", err) + } + } + + return nil +} diff --git a/internal/ghmcp/server_test.go b/internal/ghmcp/stdio_test.go similarity index 87% rename from internal/ghmcp/server_test.go rename to internal/ghmcp/stdio_test.go index 04c0989d4..6f95cefdc 100644 --- a/internal/ghmcp/server_test.go +++ b/internal/ghmcp/stdio_test.go @@ -3,6 +3,7 @@ package ghmcp import ( "testing" + "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/translations" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -14,7 +15,7 @@ func TestNewMCPServer_CreatesSuccessfully(t *testing.T) { t.Parallel() // Create a minimal server configuration - cfg := MCPServerConfig{ + cfg := github.MCPServerConfig{ Version: "test", Host: "", // defaults to github.com Token: "test-token", @@ -26,7 +27,7 @@ func TestNewMCPServer_CreatesSuccessfully(t *testing.T) { } // Create the server - server, err := NewMCPServer(cfg) + server, err := github.NewMCPServer(cfg) require.NoError(t, err, "expected server creation to succeed") require.NotNil(t, server, "expected server to be non-nil") @@ -47,12 +48,12 @@ func TestResolveEnabledToolsets(t *testing.T) { tests := []struct { name string - cfg MCPServerConfig + cfg github.MCPServerConfig expectedResult []string }{ { name: "nil toolsets without dynamic mode and no tools - use defaults", - cfg: MCPServerConfig{ + cfg: github.MCPServerConfig{ EnabledToolsets: nil, DynamicToolsets: false, EnabledTools: nil, @@ -61,7 +62,7 @@ func TestResolveEnabledToolsets(t *testing.T) { }, { name: "nil toolsets with dynamic mode - start empty", - cfg: MCPServerConfig{ + cfg: github.MCPServerConfig{ EnabledToolsets: nil, DynamicToolsets: true, EnabledTools: nil, @@ -70,7 +71,7 @@ func TestResolveEnabledToolsets(t *testing.T) { }, { name: "explicit toolsets", - cfg: MCPServerConfig{ + cfg: github.MCPServerConfig{ EnabledToolsets: []string{"repos", "issues"}, DynamicToolsets: false, }, @@ -78,7 +79,7 @@ func TestResolveEnabledToolsets(t *testing.T) { }, { name: "empty toolsets - disable all", - cfg: MCPServerConfig{ + cfg: github.MCPServerConfig{ EnabledToolsets: []string{}, DynamicToolsets: false, }, @@ -86,7 +87,7 @@ func TestResolveEnabledToolsets(t *testing.T) { }, { name: "specific tools without toolsets - no default toolsets", - cfg: MCPServerConfig{ + cfg: github.MCPServerConfig{ EnabledToolsets: nil, DynamicToolsets: false, EnabledTools: []string{"get_me"}, @@ -95,7 +96,7 @@ func TestResolveEnabledToolsets(t *testing.T) { }, { name: "dynamic mode with explicit toolsets removes all and default", - cfg: MCPServerConfig{ + cfg: github.MCPServerConfig{ EnabledToolsets: []string{"all", "repos"}, DynamicToolsets: true, }, @@ -105,7 +106,7 @@ func TestResolveEnabledToolsets(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result := resolveEnabledToolsets(tc.cfg) + result := github.ResolveEnabledToolsets(tc.cfg) assert.Equal(t, tc.expectedResult, result) }) } diff --git a/pkg/github/host.go b/pkg/github/host.go new file mode 100644 index 000000000..9334fd0af --- /dev/null +++ b/pkg/github/host.go @@ -0,0 +1,203 @@ +package github + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/github/github-mcp-server/pkg/scopes" +) + +type apiHost struct { + baseRESTURL *url.URL + graphqlURL *url.URL + uploadURL *url.URL + rawURL *url.URL +} + +func newDotcomHost() (apiHost, error) { + baseRestURL, err := url.Parse("https://api.github.com/") + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse dotcom REST URL: %w", err) + } + + gqlURL, err := url.Parse("https://api.github.com/graphql") + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse dotcom GraphQL URL: %w", err) + } + + uploadURL, err := url.Parse("https://uploads.github.com") + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse dotcom Upload URL: %w", err) + } + + rawURL, err := url.Parse("https://raw.githubusercontent.com/") + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err) + } + + return apiHost{ + baseRESTURL: baseRestURL, + graphqlURL: gqlURL, + uploadURL: uploadURL, + rawURL: rawURL, + }, nil +} + +func newGHECHost(hostname string) (apiHost, error) { + u, err := url.Parse(hostname) + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse GHEC URL: %w", err) + } + + // Unsecured GHEC would be an error + if u.Scheme == "http" { + return apiHost{}, fmt.Errorf("GHEC URL must be HTTPS") + } + + restURL, err := url.Parse(fmt.Sprintf("https://api.%s/", u.Hostname())) + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse GHEC REST URL: %w", err) + } + + gqlURL, err := url.Parse(fmt.Sprintf("https://api.%s/graphql", u.Hostname())) + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse GHEC GraphQL URL: %w", err) + } + + uploadURL, err := url.Parse(fmt.Sprintf("https://uploads.%s", u.Hostname())) + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse GHEC Upload URL: %w", err) + } + + rawURL, err := url.Parse(fmt.Sprintf("https://raw.%s/", u.Hostname())) + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err) + } + + return apiHost{ + baseRESTURL: restURL, + graphqlURL: gqlURL, + uploadURL: uploadURL, + rawURL: rawURL, + }, nil +} + +func newGHESHost(hostname string) (apiHost, error) { + u, err := url.Parse(hostname) + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse GHES URL: %w", err) + } + + restURL, err := url.Parse(fmt.Sprintf("%s://%s/api/v3/", u.Scheme, u.Hostname())) + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse GHES REST URL: %w", err) + } + + gqlURL, err := url.Parse(fmt.Sprintf("%s://%s/api/graphql", u.Scheme, u.Hostname())) + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse GHES GraphQL URL: %w", err) + } + + // Check if subdomain isolation is enabled + // See https://docs.github.com/en/enterprise-server@3.17/admin/configuring-settings/hardening-security-for-your-enterprise/enabling-subdomain-isolation#about-subdomain-isolation + hasSubdomainIsolation := checkSubdomainIsolation(u.Scheme, u.Hostname()) + + var uploadURL *url.URL + if hasSubdomainIsolation { + // With subdomain isolation: https://uploads.hostname/ + uploadURL, err = url.Parse(fmt.Sprintf("%s://uploads.%s/", u.Scheme, u.Hostname())) + } else { + // Without subdomain isolation: https://hostname/api/uploads/ + uploadURL, err = url.Parse(fmt.Sprintf("%s://%s/api/uploads/", u.Scheme, u.Hostname())) + } + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse GHES Upload URL: %w", err) + } + + var rawURL *url.URL + if hasSubdomainIsolation { + // With subdomain isolation: https://raw.hostname/ + rawURL, err = url.Parse(fmt.Sprintf("%s://raw.%s/", u.Scheme, u.Hostname())) + } else { + // Without subdomain isolation: https://hostname/raw/ + rawURL, err = url.Parse(fmt.Sprintf("%s://%s/raw/", u.Scheme, u.Hostname())) + } + if err != nil { + return apiHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err) + } + + return apiHost{ + baseRESTURL: restURL, + graphqlURL: gqlURL, + uploadURL: uploadURL, + rawURL: rawURL, + }, nil +} + +// checkSubdomainIsolation detects if GitHub Enterprise Server has subdomain isolation enabled +// by attempting to ping the raw./_ping endpoint on the subdomain. The raw subdomain must always exist for subdomain isolation. +func checkSubdomainIsolation(scheme, hostname string) bool { + subdomainURL := fmt.Sprintf("%s://raw.%s/_ping", scheme, hostname) + + client := &http.Client{ + Timeout: 5 * time.Second, + // Don't follow redirects - we just want to check if the endpoint exists + //nolint:revive // parameters are required by http.Client.CheckRedirect signature + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err := client.Get(subdomainURL) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.StatusCode == http.StatusOK +} + +// ParseAPIHost parses the given host string and returns the appropriate apiHost configuration. +func ParseAPIHost(s string) (apiHost, error) { + if s == "" { + return newDotcomHost() + } + + u, err := url.Parse(s) + if err != nil { + return apiHost{}, fmt.Errorf("could not parse host as URL: %s", s) + } + + if u.Scheme == "" { + return apiHost{}, fmt.Errorf("host must have a scheme (http or https): %s", s) + } + + if strings.HasSuffix(u.Hostname(), "github.com") { + return newDotcomHost() + } + + if strings.HasSuffix(u.Hostname(), "ghe.com") { + return newGHECHost(s) + } + + return newGHESHost(s) +} + +// FetchTokenScopesForHost fetches the OAuth scopes for a token from the GitHub API. +// It constructs the appropriate API host URL based on the configured host. +func FetchTokenScopesForHost(ctx context.Context, token, host string) ([]string, error) { + apiHost, err := ParseAPIHost(host) + if err != nil { + return nil, fmt.Errorf("failed to parse API host: %w", err) + } + + fetcher := scopes.NewFetcher(scopes.FetcherOptions{ + APIHost: apiHost.baseRESTURL.String(), + }) + + return fetcher.FetchTokenScopes(ctx, token) +} diff --git a/pkg/github/params.go b/pkg/github/params.go new file mode 100644 index 000000000..68705d649 --- /dev/null +++ b/pkg/github/params.go @@ -0,0 +1,385 @@ +package github + +import ( + "fmt" + "strconv" + + "github.com/google/jsonschema-go/jsonschema" +) + +// OptionalParamOK is a helper function that can be used to fetch a requested parameter from the request. +// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong. +func OptionalParamOK[T any, A map[string]any](args A, p string) (value T, ok bool, err error) { + // Check if the parameter is present in the request + val, exists := args[p] + if !exists { + // Not present, return zero value, false, no error + return + } + + // Check if the parameter is of the expected type + value, ok = val.(T) + if !ok { + // Present but wrong type + err = fmt.Errorf("parameter %s is not of type %T, is %T", p, value, val) + ok = true // Set ok to true because the parameter *was* present, even if wrong type + return + } + + // Present and correct type + ok = true + return +} + +// RequiredParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func RequiredParam[T comparable](args map[string]any, p string) (T, error) { + var zero T + + // Check if the parameter is present in the request + if _, ok := args[p]; !ok { + return zero, fmt.Errorf("missing required parameter: %s", p) + } + + // Check if the parameter is of the expected type + val, ok := args[p].(T) + if !ok { + return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) + } + + if val == zero { + return zero, fmt.Errorf("missing required parameter: %s", p) + } + + return val, nil +} + +// RequiredInt is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func RequiredInt(args map[string]any, p string) (int, error) { + v, err := RequiredParam[float64](args, p) + if err != nil { + return 0, err + } + return int(v), nil +} + +// RequiredBigInt is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type (float64). +// 3. Checks if the parameter is not empty, i.e: non-zero value. +// 4. Validates that the float64 value can be safely converted to int64 without truncation. +func RequiredBigInt(args map[string]any, p string) (int64, error) { + v, err := RequiredParam[float64](args, p) + if err != nil { + return 0, err + } + + result := int64(v) + // Check if converting back produces the same value to avoid silent truncation + if float64(result) != v { + return 0, fmt.Errorf("parameter %s value %f is too large to fit in int64", p, v) + } + return result, nil +} + +// OptionalParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func OptionalParam[T any](args map[string]any, p string) (T, error) { + var zero T + + // Check if the parameter is present in the request + if _, ok := args[p]; !ok { + return zero, nil + } + + // Check if the parameter is of the expected type + if _, ok := args[p].(T); !ok { + return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, args[p]) + } + + return args[p].(T), nil +} + +// OptionalIntParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func OptionalIntParam(args map[string]any, p string) (int, error) { + v, err := OptionalParam[float64](args, p) + if err != nil { + return 0, err + } + return int(v), nil +} + +// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalIntParam, but it also takes a default value. +func OptionalIntParamWithDefault(args map[string]any, p string, d int) (int, error) { + v, err := OptionalIntParam(args, p) + if err != nil { + return 0, err + } + if v == 0 { + return d, nil + } + return v, nil +} + +// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalBoolParam, but it also takes a default value. +func OptionalBoolParamWithDefault(args map[string]any, p string, d bool) (bool, error) { + _, ok := args[p] + v, err := OptionalParam[bool](args, p) + if err != nil { + return false, err + } + if !ok { + return d, nil + } + return v, nil +} + +// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, iterates the elements and checks each is a string +func OptionalStringArrayParam(args map[string]any, p string) ([]string, error) { + // Check if the parameter is present in the request + if _, ok := args[p]; !ok { + return []string{}, nil + } + + switch v := args[p].(type) { + case nil: + return []string{}, nil + case []string: + return v, nil + case []any: + strSlice := make([]string, len(v)) + for i, v := range v { + s, ok := v.(string) + if !ok { + return []string{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) + } + strSlice[i] = s + } + return strSlice, nil + default: + return []string{}, fmt.Errorf("parameter %s could not be coerced to []string, is %T", p, args[p]) + } +} + +func convertStringSliceToBigIntSlice(s []string) ([]int64, error) { + int64Slice := make([]int64, len(s)) + for i, str := range s { + val, err := convertStringToBigInt(str, 0) + if err != nil { + return nil, fmt.Errorf("failed to convert element %d (%s) to int64: %w", i, str, err) + } + int64Slice[i] = val + } + return int64Slice, nil +} + +func convertStringToBigInt(s string, def int64) (int64, error) { + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return def, fmt.Errorf("failed to convert string %s to int64: %w", s, err) + } + return v, nil +} + +// OptionalBigIntArrayParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns an empty slice +// 2. If it is present, iterates the elements, checks each is a string, and converts them to int64 values +func OptionalBigIntArrayParam(args map[string]any, p string) ([]int64, error) { + // Check if the parameter is present in the request + if _, ok := args[p]; !ok { + return []int64{}, nil + } + + switch v := args[p].(type) { + case nil: + return []int64{}, nil + case []string: + return convertStringSliceToBigIntSlice(v) + case []any: + int64Slice := make([]int64, len(v)) + for i, v := range v { + s, ok := v.(string) + if !ok { + return []int64{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) + } + val, err := convertStringToBigInt(s, 0) + if err != nil { + return []int64{}, fmt.Errorf("parameter %s: failed to convert element %d (%s) to int64: %w", p, i, s, err) + } + int64Slice[i] = val + } + return int64Slice, nil + default: + return []int64{}, fmt.Errorf("parameter %s could not be coerced to []int64, is %T", p, args[p]) + } +} + +// WithPagination adds REST API pagination parameters to a tool. +// https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api +func WithPagination(schema *jsonschema.Schema) *jsonschema.Schema { + schema.Properties["page"] = &jsonschema.Schema{ + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + } + + schema.Properties["perPage"] = &jsonschema.Schema{ + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + } + + return schema +} + +// WithUnifiedPagination adds REST API pagination parameters to a tool. +// GraphQL tools will use this and convert page/perPage to GraphQL cursor parameters internally. +func WithUnifiedPagination(schema *jsonschema.Schema) *jsonschema.Schema { + schema.Properties["page"] = &jsonschema.Schema{ + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + } + + schema.Properties["perPage"] = &jsonschema.Schema{ + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + } + + schema.Properties["after"] = &jsonschema.Schema{ + Type: "string", + Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + } + + return schema +} + +// WithCursorPagination adds only cursor-based pagination parameters to a tool (no page parameter). +func WithCursorPagination(schema *jsonschema.Schema) *jsonschema.Schema { + schema.Properties["perPage"] = &jsonschema.Schema{ + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + } + + schema.Properties["after"] = &jsonschema.Schema{ + Type: "string", + Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + } + + return schema +} + +type PaginationParams struct { + Page int + PerPage int + After string +} + +// OptionalPaginationParams returns the "page", "perPage", and "after" parameters from the request, +// or their default values if not present, "page" default is 1, "perPage" default is 30. +// In future, we may want to make the default values configurable, or even have this +// function returned from `withPagination`, where the defaults are provided alongside +// the min/max values. +func OptionalPaginationParams(args map[string]any) (PaginationParams, error) { + page, err := OptionalIntParamWithDefault(args, "page", 1) + if err != nil { + return PaginationParams{}, err + } + perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) + if err != nil { + return PaginationParams{}, err + } + after, err := OptionalParam[string](args, "after") + if err != nil { + return PaginationParams{}, err + } + return PaginationParams{ + Page: page, + PerPage: perPage, + After: after, + }, nil +} + +// OptionalCursorPaginationParams returns the "perPage" and "after" parameters from the request, +// without the "page" parameter, suitable for cursor-based pagination only. +func OptionalCursorPaginationParams(args map[string]any) (CursorPaginationParams, error) { + perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) + if err != nil { + return CursorPaginationParams{}, err + } + after, err := OptionalParam[string](args, "after") + if err != nil { + return CursorPaginationParams{}, err + } + return CursorPaginationParams{ + PerPage: perPage, + After: after, + }, nil +} + +type CursorPaginationParams struct { + PerPage int + After string +} + +// ToGraphQLParams converts cursor pagination parameters to GraphQL-specific parameters. +func (p CursorPaginationParams) ToGraphQLParams() (*GraphQLPaginationParams, error) { + if p.PerPage > 100 { + return nil, fmt.Errorf("perPage value %d exceeds maximum of 100", p.PerPage) + } + if p.PerPage < 0 { + return nil, fmt.Errorf("perPage value %d cannot be negative", p.PerPage) + } + first := int32(p.PerPage) + + var after *string + if p.After != "" { + after = &p.After + } + + return &GraphQLPaginationParams{ + First: &first, + After: after, + }, nil +} + +type GraphQLPaginationParams struct { + First *int32 + After *string +} + +// ToGraphQLParams converts REST API pagination parameters to GraphQL-specific parameters. +// This converts page/perPage to first parameter for GraphQL queries. +// If After is provided, it takes precedence over page-based pagination. +func (p PaginationParams) ToGraphQLParams() (*GraphQLPaginationParams, error) { + // Convert to CursorPaginationParams and delegate to avoid duplication + cursor := CursorPaginationParams{ + PerPage: p.PerPage, + After: p.After, + } + return cursor.ToGraphQLParams() +} diff --git a/pkg/github/params_test.go b/pkg/github/params_test.go new file mode 100644 index 000000000..9d7cfe432 --- /dev/null +++ b/pkg/github/params_test.go @@ -0,0 +1,503 @@ +package github + +import ( + "fmt" + "testing" + + "github.com/google/go-github/v79/github" + "github.com/stretchr/testify/assert" +) + +func Test_IsAcceptedError(t *testing.T) { + tests := []struct { + name string + err error + expectAccepted bool + }{ + { + name: "github AcceptedError", + err: &github.AcceptedError{}, + expectAccepted: true, + }, + { + name: "regular error", + err: fmt.Errorf("some other error"), + expectAccepted: false, + }, + { + name: "nil error", + err: nil, + expectAccepted: false, + }, + { + name: "wrapped AcceptedError", + err: fmt.Errorf("wrapped: %w", &github.AcceptedError{}), + expectAccepted: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := isAcceptedError(tc.err) + assert.Equal(t, tc.expectAccepted, result) + }) + } +} + +func Test_RequiredStringParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected string + expectError bool + }{ + { + name: "valid string parameter", + params: map[string]interface{}{"name": "test-value"}, + paramName: "name", + expected: "test-value", + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "name", + expected: "", + expectError: true, + }, + { + name: "empty string parameter", + params: map[string]interface{}{"name": ""}, + paramName: "name", + expected: "", + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"name": 123}, + paramName: "name", + expected: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := RequiredParam[string](tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalStringParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected string + expectError bool + }{ + { + name: "valid string parameter", + params: map[string]interface{}{"name": "test-value"}, + paramName: "name", + expected: "test-value", + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "name", + expected: "", + expectError: false, + }, + { + name: "empty string parameter", + params: map[string]interface{}{"name": ""}, + paramName: "name", + expected: "", + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"name": 123}, + paramName: "name", + expected: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalParam[string](tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_RequiredInt(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := RequiredInt(tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} +func Test_OptionalIntParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "zero value", + params: map[string]interface{}{"count": float64(0)}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalIntParam(tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalNumberParamWithDefault(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + defaultVal int + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + defaultVal: 10, + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + defaultVal: 10, + expected: 10, + expectError: false, + }, + { + name: "zero value", + params: map[string]interface{}{"count": float64(0)}, + paramName: "count", + defaultVal: 10, + expected: 10, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + defaultVal: 10, + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalIntParamWithDefault(tc.params, tc.paramName, tc.defaultVal) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalBooleanParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected bool + expectError bool + }{ + { + name: "true value", + params: map[string]interface{}{"flag": true}, + paramName: "flag", + expected: true, + expectError: false, + }, + { + name: "false value", + params: map[string]interface{}{"flag": false}, + paramName: "flag", + expected: false, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "flag", + expected: false, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"flag": "not-a-boolean"}, + paramName: "flag", + expected: false, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalParam[bool](tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func TestOptionalStringArrayParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected []string + expectError bool + }{ + { + name: "parameter not in request", + params: map[string]any{}, + paramName: "flag", + expected: []string{}, + expectError: false, + }, + { + name: "valid any array parameter", + params: map[string]any{ + "flag": []any{"v1", "v2"}, + }, + paramName: "flag", + expected: []string{"v1", "v2"}, + expectError: false, + }, + { + name: "valid string array parameter", + params: map[string]any{ + "flag": []string{"v1", "v2"}, + }, + paramName: "flag", + expected: []string{"v1", "v2"}, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]any{ + "flag": 1, + }, + paramName: "flag", + expected: []string{}, + expectError: true, + }, + { + name: "wrong slice type parameter", + params: map[string]any{ + "flag": []any{"foo", 2}, + }, + paramName: "flag", + expected: []string{}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalStringArrayParam(tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func TestOptionalPaginationParams(t *testing.T) { + tests := []struct { + name string + params map[string]any + expected PaginationParams + expectError bool + }{ + { + name: "no pagination parameters, default values", + params: map[string]any{}, + expected: PaginationParams{ + Page: 1, + PerPage: 30, + }, + expectError: false, + }, + { + name: "page parameter, default perPage", + params: map[string]any{ + "page": float64(2), + }, + expected: PaginationParams{ + Page: 2, + PerPage: 30, + }, + expectError: false, + }, + { + name: "perPage parameter, default page", + params: map[string]any{ + "perPage": float64(50), + }, + expected: PaginationParams{ + Page: 1, + PerPage: 50, + }, + expectError: false, + }, + { + name: "page and perPage parameters", + params: map[string]any{ + "page": float64(2), + "perPage": float64(50), + }, + expected: PaginationParams{ + Page: 2, + PerPage: 50, + }, + expectError: false, + }, + { + name: "invalid page parameter", + params: map[string]any{ + "page": "not-a-number", + }, + expected: PaginationParams{}, + expectError: true, + }, + { + name: "invalid perPage parameter", + params: map[string]any{ + "perPage": "not-a-number", + }, + expected: PaginationParams{}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalPaginationParams(tc.params) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index 8248da58f..acc2e9f5b 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -5,431 +5,376 @@ import ( "encoding/json" "errors" "fmt" - "strconv" + "log/slog" + "net/http" + "os" "strings" + "time" + ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/octicons" + "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" - "github.com/google/jsonschema-go/jsonschema" + gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/shurcooL/githubv4" ) -// NewServer creates a new GitHub MCP server with the specified GH client and logger. +type MCPServerConfig struct { + // Version of the server + Version string -func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { - if opts == nil { - opts = &mcp.ServerOptions{} - } + // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) + Host string - // Create a new MCP server - s := mcp.NewServer(&mcp.Implementation{ - Name: "github-mcp-server", - Title: "GitHub MCP Server", - Version: version, - Icons: octicons.Icons("mark-github"), - }, opts) + // GitHub Token to authenticate with the GitHub API + Token string - return s -} + // EnabledToolsets is a list of toolsets to enable + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#tool-configuration + EnabledToolsets []string -func CompletionsHandler(getClient GetClientFn) func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { - return func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { - switch req.Params.Ref.Type { - case "ref/resource": - if strings.HasPrefix(req.Params.Ref.URI, "repo://") { - return RepositoryResourceCompletionHandler(getClient)(ctx, req) - } - return nil, fmt.Errorf("unsupported resource URI: %s", req.Params.Ref.URI) - case "ref/prompt": - return nil, nil - default: - return nil, fmt.Errorf("unsupported ref type: %s", req.Params.Ref.Type) - } - } -} + // EnabledTools is a list of specific tools to enable (additive to toolsets) + // When specified, these tools are registered in addition to any specified toolset tools + EnabledTools []string -// OptionalParamOK is a helper function that can be used to fetch a requested parameter from the request. -// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong. -func OptionalParamOK[T any, A map[string]any](args A, p string) (value T, ok bool, err error) { - // Check if the parameter is present in the request - val, exists := args[p] - if !exists { - // Not present, return zero value, false, no error - return - } + // EnabledFeatures is a list of feature flags that are enabled + // Items with FeatureFlagEnable matching an entry in this list will be available + EnabledFeatures []string - // Check if the parameter is of the expected type - value, ok = val.(T) - if !ok { - // Present but wrong type - err = fmt.Errorf("parameter %s is not of type %T, is %T", p, value, val) - ok = true // Set ok to true because the parameter *was* present, even if wrong type - return - } + // Whether to enable dynamic toolsets + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery + DynamicToolsets bool - // Present and correct type - ok = true - return -} + // ReadOnly indicates if we should only offer read-only tools + ReadOnly bool -// isAcceptedError checks if the error is an accepted error. -func isAcceptedError(err error) bool { - var acceptedError *github.AcceptedError - return errors.As(err, &acceptedError) -} + // Translator provides translated text for the server tooling + Translator translations.TranslationHelperFunc -// RequiredParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request. -// 2. Checks if the parameter is of the expected type. -// 3. Checks if the parameter is not empty, i.e: non-zero value -func RequiredParam[T comparable](args map[string]any, p string) (T, error) { - var zero T - - // Check if the parameter is present in the request - if _, ok := args[p]; !ok { - return zero, fmt.Errorf("missing required parameter: %s", p) - } + // Content window size + ContentWindowSize int - // Check if the parameter is of the expected type - val, ok := args[p].(T) - if !ok { - return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) - } + // LockdownMode indicates if we should enable lockdown mode + LockdownMode bool - if val == zero { - return zero, fmt.Errorf("missing required parameter: %s", p) - } + // Logger is used for logging within the server + Logger *slog.Logger + // RepoAccessTTL overrides the default TTL for repository access cache entries. + RepoAccessTTL *time.Duration - return val, nil + // TokenScopes contains the OAuth scopes available to the token. + // When non-nil, tools requiring scopes not in this list will be hidden. + // This is used for PAT scope filtering where we can't issue scope challenges. + TokenScopes []string } -// RequiredInt is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request. -// 2. Checks if the parameter is of the expected type. -// 3. Checks if the parameter is not empty, i.e: non-zero value -func RequiredInt(args map[string]any, p string) (int, error) { - v, err := RequiredParam[float64](args, p) - if err != nil { - return 0, err - } - return int(v), nil +// githubClients holds all the GitHub API clients created for a server instance. +type githubClients struct { + rest *gogithub.Client + gql *githubv4.Client + gqlHTTP *http.Client // retained for middleware to modify transport + raw *raw.Client + repoAccess *lockdown.RepoAccessCache } -// RequiredBigInt is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request. -// 2. Checks if the parameter is of the expected type (float64). -// 3. Checks if the parameter is not empty, i.e: non-zero value. -// 4. Validates that the float64 value can be safely converted to int64 without truncation. -func RequiredBigInt(args map[string]any, p string) (int64, error) { - v, err := RequiredParam[float64](args, p) - if err != nil { - return 0, err +// createGitHubClients creates all the GitHub API clients needed by the server. +func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, error) { + // Construct REST client + restClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token) + restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version) + restClient.BaseURL = apiHost.baseRESTURL + restClient.UploadURL = apiHost.uploadURL + + // Construct GraphQL client + // We use NewEnterpriseClient unconditionally since we already parsed the API host + gqlHTTPClient := &http.Client{ + Transport: &bearerAuthTransport{ + transport: http.DefaultTransport, + token: cfg.Token, + }, + } + gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) + + // Create raw content client (shares REST client's HTTP transport) + rawClient := raw.NewClient(restClient, apiHost.rawURL) + + // Set up repo access cache for lockdown mode + var repoAccessCache *lockdown.RepoAccessCache + if cfg.LockdownMode { + opts := []lockdown.RepoAccessOption{ + lockdown.WithLogger(cfg.Logger.With("component", "lockdown")), + } + if cfg.RepoAccessTTL != nil { + opts = append(opts, lockdown.WithTTL(*cfg.RepoAccessTTL)) + } + repoAccessCache = lockdown.GetInstance(gqlClient, opts...) } - result := int64(v) - // Check if converting back produces the same value to avoid silent truncation - if float64(result) != v { - return 0, fmt.Errorf("parameter %s value %f is too large to fit in int64", p, v) - } - return result, nil + return &githubClients{ + rest: restClient, + gql: gqlClient, + gqlHTTP: gqlHTTPClient, + raw: rawClient, + repoAccess: repoAccessCache, + }, nil } -// OptionalParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns its zero-value -// 2. If it is present, it checks if the parameter is of the expected type and returns it -func OptionalParam[T any](args map[string]any, p string) (T, error) { - var zero T +// ResolveEnabledToolsets determines which toolsets should be enabled based on config. +// Returns nil for "use defaults", empty slice for "none", or explicit list. +func ResolveEnabledToolsets(cfg MCPServerConfig) []string { + enabledToolsets := cfg.EnabledToolsets - // Check if the parameter is present in the request - if _, ok := args[p]; !ok { - return zero, nil + // In dynamic mode, remove "all" and "default" since users enable toolsets on demand + if cfg.DynamicToolsets && enabledToolsets != nil { + enabledToolsets = RemoveToolset(enabledToolsets, string(ToolsetMetadataAll.ID)) + enabledToolsets = RemoveToolset(enabledToolsets, string(ToolsetMetadataDefault.ID)) } - // Check if the parameter is of the expected type - if _, ok := args[p].(T); !ok { - return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, args[p]) + if enabledToolsets != nil { + return enabledToolsets } - - return args[p].(T), nil -} - -// OptionalIntParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns its zero-value -// 2. If it is present, it checks if the parameter is of the expected type and returns it -func OptionalIntParam(args map[string]any, p string) (int, error) { - v, err := OptionalParam[float64](args, p) - if err != nil { - return 0, err + if cfg.DynamicToolsets { + // Dynamic mode with no toolsets specified: start empty so users enable on demand + return []string{} + } + if len(cfg.EnabledTools) > 0 { + // When specific tools are requested but no toolsets, don't use default toolsets + // This matches the original behavior: --tools=X alone registers only X + return []string{} } - return int(v), nil + // nil means "use defaults" in WithToolsets + return nil } -// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request -// similar to optionalIntParam, but it also takes a default value. -func OptionalIntParamWithDefault(args map[string]any, p string, d int) (int, error) { - v, err := OptionalIntParam(args, p) +func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { + apiHost, err := ParseAPIHost(cfg.Host) if err != nil { - return 0, err + return nil, fmt.Errorf("failed to parse API host: %w", err) } - if v == 0 { - return d, nil - } - return v, nil -} -// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request -// similar to optionalBoolParam, but it also takes a default value. -func OptionalBoolParamWithDefault(args map[string]any, p string, d bool) (bool, error) { - _, ok := args[p] - v, err := OptionalParam[bool](args, p) + clients, err := createGitHubClients(cfg, apiHost) if err != nil { - return false, err - } - if !ok { - return d, nil + return nil, fmt.Errorf("failed to create GitHub clients: %w", err) } - return v, nil -} -// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns its zero-value -// 2. If it is present, iterates the elements and checks each is a string -func OptionalStringArrayParam(args map[string]any, p string) ([]string, error) { - // Check if the parameter is present in the request - if _, ok := args[p]; !ok { - return []string{}, nil + enabledToolsets := ResolveEnabledToolsets(cfg) + + // For instruction generation, we need actual toolset names (not nil). + // nil means "use defaults" in inventory, so expand it for instructions. + instructionToolsets := enabledToolsets + if instructionToolsets == nil { + instructionToolsets = GetDefaultToolsetIDs() } - switch v := args[p].(type) { - case nil: - return []string{}, nil - case []string: - return v, nil - case []any: - strSlice := make([]string, len(v)) - for i, v := range v { - s, ok := v.(string) - if !ok { - return []string{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) - } - strSlice[i] = s - } - return strSlice, nil - default: - return []string{}, fmt.Errorf("parameter %s could not be coerced to []string, is %T", p, args[p]) + // Create the MCP server + serverOpts := &mcp.ServerOptions{ + Instructions: GenerateInstructions(instructionToolsets), + Logger: cfg.Logger, + CompletionHandler: CompletionsHandler(func(_ context.Context) (*gogithub.Client, error) { + return clients.rest, nil + }), } -} -func convertStringSliceToBigIntSlice(s []string) ([]int64, error) { - int64Slice := make([]int64, len(s)) - for i, str := range s { - val, err := convertStringToBigInt(str, 0) - if err != nil { - return nil, fmt.Errorf("failed to convert element %d (%s) to int64: %w", i, str, err) + // In dynamic mode, explicitly advertise capabilities since tools/resources/prompts + // may be enabled at runtime even if none are registered initially. + if cfg.DynamicToolsets { + serverOpts.Capabilities = &mcp.ServerCapabilities{ + Tools: &mcp.ToolCapabilities{}, + Resources: &mcp.ResourceCapabilities{}, + Prompts: &mcp.PromptCapabilities{}, } - int64Slice[i] = val } - return int64Slice, nil -} -func convertStringToBigInt(s string, def int64) (int64, error) { - v, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return def, fmt.Errorf("failed to convert string %s to int64: %w", s, err) - } - return v, nil -} + ghServer := NewServer(cfg.Version, serverOpts) -// OptionalBigIntArrayParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns an empty slice -// 2. If it is present, iterates the elements, checks each is a string, and converts them to int64 values -func OptionalBigIntArrayParam(args map[string]any, p string) ([]int64, error) { - // Check if the parameter is present in the request - if _, ok := args[p]; !ok { - return []int64{}, nil - } + // Add middlewares + ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) + ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, clients.rest, clients.gqlHTTP)) - switch v := args[p].(type) { - case nil: - return []int64{}, nil - case []string: - return convertStringSliceToBigIntSlice(v) - case []any: - int64Slice := make([]int64, len(v)) - for i, v := range v { - s, ok := v.(string) - if !ok { - return []int64{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) - } - val, err := convertStringToBigInt(s, 0) - if err != nil { - return []int64{}, fmt.Errorf("parameter %s: failed to convert element %d (%s) to int64: %w", p, i, s, err) - } - int64Slice[i] = val + // Create dependencies for tool handlers + deps := NewBaseDeps( + clients.rest, + clients.gql, + clients.raw, + clients.repoAccess, + cfg.Translator, + FeatureFlags{LockdownMode: cfg.LockdownMode}, + cfg.ContentWindowSize, + ) + + // Inject dependencies into context for all tool handlers + ghServer.AddReceivingMiddleware(func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return next(ContextWithDeps(ctx, deps), method, req) } - return int64Slice, nil - default: - return []int64{}, fmt.Errorf("parameter %s could not be coerced to []int64, is %T", p, args[p]) - } -} + }) -// WithPagination adds REST API pagination parameters to a tool. -// https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api -func WithPagination(schema *jsonschema.Schema) *jsonschema.Schema { - schema.Properties["page"] = &jsonschema.Schema{ - Type: "number", - Description: "Page number for pagination (min 1)", - Minimum: jsonschema.Ptr(1.0), - } + // Build and register the tool/resource/prompt inventory + inventoryBuilder := NewInventory(cfg.Translator). + WithDeprecatedAliases(DeprecatedToolAliases). + WithReadOnly(cfg.ReadOnly). + WithToolsets(enabledToolsets). + WithTools(CleanTools(cfg.EnabledTools)). + WithFeatureChecker(createFeatureChecker(cfg.EnabledFeatures)) - schema.Properties["perPage"] = &jsonschema.Schema{ - Type: "number", - Description: "Results per page for pagination (min 1, max 100)", - Minimum: jsonschema.Ptr(1.0), - Maximum: jsonschema.Ptr(100.0), + // Apply token scope filtering if scopes are known (for PAT filtering) + if cfg.TokenScopes != nil { + inventoryBuilder = inventoryBuilder.WithFilter(CreateToolScopeFilter(cfg.TokenScopes)) } - return schema -} + inventory := inventoryBuilder.Build() -// WithUnifiedPagination adds REST API pagination parameters to a tool. -// GraphQL tools will use this and convert page/perPage to GraphQL cursor parameters internally. -func WithUnifiedPagination(schema *jsonschema.Schema) *jsonschema.Schema { - schema.Properties["page"] = &jsonschema.Schema{ - Type: "number", - Description: "Page number for pagination (min 1)", - Minimum: jsonschema.Ptr(1.0), + if unrecognized := inventory.UnrecognizedToolsets(); len(unrecognized) > 0 { + fmt.Fprintf(os.Stderr, "Warning: unrecognized toolsets ignored: %s\n", strings.Join(unrecognized, ", ")) } - schema.Properties["perPage"] = &jsonschema.Schema{ - Type: "number", - Description: "Results per page for pagination (min 1, max 100)", - Minimum: jsonschema.Ptr(1.0), - Maximum: jsonschema.Ptr(100.0), - } + // Register GitHub tools/resources/prompts from the inventory. + // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets + // is empty - users enable toolsets at runtime via the dynamic tools below (but can + // enable toolsets or tools explicitly that do need registration). + inventory.RegisterAll(context.Background(), ghServer, deps) - schema.Properties["after"] = &jsonschema.Schema{ - Type: "string", - Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + // Register dynamic toolset management tools (enable/disable) - these are separate + // meta-tools that control the inventory, not part of the inventory itself + if cfg.DynamicToolsets { + registerDynamicTools(ghServer, inventory, deps, cfg.Translator) } - return schema + return ghServer, nil } -// WithCursorPagination adds only cursor-based pagination parameters to a tool (no page parameter). -func WithCursorPagination(schema *jsonschema.Schema) *jsonschema.Schema { - schema.Properties["perPage"] = &jsonschema.Schema{ - Type: "number", - Description: "Results per page for pagination (min 1, max 100)", - Minimum: jsonschema.Ptr(1.0), - Maximum: jsonschema.Ptr(100.0), +// registerDynamicTools adds the dynamic toolset enable/disable tools to the server. +func registerDynamicTools(server *mcp.Server, inventory *inventory.Inventory, deps *BaseDeps, t translations.TranslationHelperFunc) { + dynamicDeps := DynamicToolDependencies{ + Server: server, + Inventory: inventory, + ToolDeps: deps, + T: t, } - - schema.Properties["after"] = &jsonschema.Schema{ - Type: "string", - Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + for _, tool := range DynamicTools(inventory) { + tool.RegisterFunc(server, dynamicDeps) } +} - return schema +// createFeatureChecker returns a FeatureFlagChecker that checks if a flag name +// is present in the provided list of enabled features. For the local server, +// this is populated from the --features CLI flag. +func createFeatureChecker(enabledFeatures []string) inventory.FeatureFlagChecker { + // Build a set for O(1) lookup + featureSet := make(map[string]bool, len(enabledFeatures)) + for _, f := range enabledFeatures { + featureSet[f] = true + } + return func(_ context.Context, flagName string) (bool, error) { + return featureSet[flagName], nil + } } -type PaginationParams struct { - Page int - PerPage int - After string +type userAgentTransport struct { + transport http.RoundTripper + agent string } -// OptionalPaginationParams returns the "page", "perPage", and "after" parameters from the request, -// or their default values if not present, "page" default is 1, "perPage" default is 30. -// In future, we may want to make the default values configurable, or even have this -// function returned from `withPagination`, where the defaults are provided alongside -// the min/max values. -func OptionalPaginationParams(args map[string]any) (PaginationParams, error) { - page, err := OptionalIntParamWithDefault(args, "page", 1) - if err != nil { - return PaginationParams{}, err - } - perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) - if err != nil { - return PaginationParams{}, err - } - after, err := OptionalParam[string](args, "after") - if err != nil { - return PaginationParams{}, err - } - return PaginationParams{ - Page: page, - PerPage: perPage, - After: after, - }, nil +func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set("User-Agent", t.agent) + return t.transport.RoundTrip(req) } -// OptionalCursorPaginationParams returns the "perPage" and "after" parameters from the request, -// without the "page" parameter, suitable for cursor-based pagination only. -func OptionalCursorPaginationParams(args map[string]any) (CursorPaginationParams, error) { - perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) - if err != nil { - return CursorPaginationParams{}, err - } - after, err := OptionalParam[string](args, "after") - if err != nil { - return CursorPaginationParams{}, err - } - return CursorPaginationParams{ - PerPage: perPage, - After: after, - }, nil +type bearerAuthTransport struct { + transport http.RoundTripper + token string } -type CursorPaginationParams struct { - PerPage int - After string +func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set("Authorization", "Bearer "+t.token) + return t.transport.RoundTrip(req) } -// ToGraphQLParams converts cursor pagination parameters to GraphQL-specific parameters. -func (p CursorPaginationParams) ToGraphQLParams() (*GraphQLPaginationParams, error) { - if p.PerPage > 100 { - return nil, fmt.Errorf("perPage value %d exceeds maximum of 100", p.PerPage) +func addGitHubAPIErrorToContext(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) { + // Ensure the context is cleared of any previous errors + // as context isn't propagated through middleware + ctx = ghErrors.ContextWithGitHubErrors(ctx) + return next(ctx, method, req) } - if p.PerPage < 0 { - return nil, fmt.Errorf("perPage value %d cannot be negative", p.PerPage) +} + +func addUserAgentsMiddleware(cfg MCPServerConfig, restClient *gogithub.Client, gqlHTTPClient *http.Client) func(next mcp.MethodHandler) mcp.MethodHandler { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, request mcp.Request) (result mcp.Result, err error) { + if method != "initialize" { + return next(ctx, method, request) + } + + initializeRequest, ok := request.(*mcp.InitializeRequest) + if !ok { + return next(ctx, method, request) + } + + message := initializeRequest + userAgent := fmt.Sprintf( + "github-mcp-server/%s (%s/%s)", + cfg.Version, + message.Params.ClientInfo.Name, + message.Params.ClientInfo.Version, + ) + + restClient.UserAgent = userAgent + + gqlHTTPClient.Transport = &userAgentTransport{ + transport: gqlHTTPClient.Transport, + agent: userAgent, + } + + return next(ctx, method, request) + } } - first := int32(p.PerPage) +} - var after *string - if p.After != "" { - after = &p.After +// NewServer creates a new GitHub MCP server with the specified GH client and logger. +func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { + if opts == nil { + opts = &mcp.ServerOptions{} } - return &GraphQLPaginationParams{ - First: &first, - After: after, - }, nil -} + // Create a new MCP server + s := mcp.NewServer(&mcp.Implementation{ + Name: "github-mcp-server", + Title: "GitHub MCP Server", + Version: version, + Icons: octicons.Icons("mark-github"), + }, opts) -type GraphQLPaginationParams struct { - First *int32 - After *string + return s } -// ToGraphQLParams converts REST API pagination parameters to GraphQL-specific parameters. -// This converts page/perPage to first parameter for GraphQL queries. -// If After is provided, it takes precedence over page-based pagination. -func (p PaginationParams) ToGraphQLParams() (*GraphQLPaginationParams, error) { - // Convert to CursorPaginationParams and delegate to avoid duplication - cursor := CursorPaginationParams{ - PerPage: p.PerPage, - After: p.After, +func CompletionsHandler(getClient GetClientFn) func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { + return func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { + switch req.Params.Ref.Type { + case "ref/resource": + if strings.HasPrefix(req.Params.Ref.URI, "repo://") { + return RepositoryResourceCompletionHandler(getClient)(ctx, req) + } + return nil, fmt.Errorf("unsupported resource URI: %s", req.Params.Ref.URI) + case "ref/prompt": + return nil, nil + default: + return nil, fmt.Errorf("unsupported ref type: %s", req.Params.Ref.Type) + } } - return cursor.ToGraphQLParams() +} + +// isAcceptedError checks if the error is an accepted error. +func isAcceptedError(err error) bool { + var acceptedError *github.AcceptedError + return errors.As(err, &acceptedError) } func MarshalledTextResult(v any) *mcp.CallToolResult { diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index a59cd9a93..66f3a28e0 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net/http" - "testing" "time" "github.com/github/github-mcp-server/pkg/lockdown" @@ -14,7 +13,6 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/shurcooL/githubv4" - "github.com/stretchr/testify/assert" ) // stubDeps is a test helper that implements ToolDependencies with configurable behavior. @@ -100,497 +98,3 @@ func badRequestHandler(msg string) http.HandlerFunc { http.Error(w, string(b), http.StatusBadRequest) } } - -func Test_IsAcceptedError(t *testing.T) { - tests := []struct { - name string - err error - expectAccepted bool - }{ - { - name: "github AcceptedError", - err: &github.AcceptedError{}, - expectAccepted: true, - }, - { - name: "regular error", - err: fmt.Errorf("some other error"), - expectAccepted: false, - }, - { - name: "nil error", - err: nil, - expectAccepted: false, - }, - { - name: "wrapped AcceptedError", - err: fmt.Errorf("wrapped: %w", &github.AcceptedError{}), - expectAccepted: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := isAcceptedError(tc.err) - assert.Equal(t, tc.expectAccepted, result) - }) - } -} - -func Test_RequiredStringParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected string - expectError bool - }{ - { - name: "valid string parameter", - params: map[string]interface{}{"name": "test-value"}, - paramName: "name", - expected: "test-value", - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "name", - expected: "", - expectError: true, - }, - { - name: "empty string parameter", - params: map[string]interface{}{"name": ""}, - paramName: "name", - expected: "", - expectError: true, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"name": 123}, - paramName: "name", - expected: "", - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := RequiredParam[string](tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func Test_OptionalStringParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected string - expectError bool - }{ - { - name: "valid string parameter", - params: map[string]interface{}{"name": "test-value"}, - paramName: "name", - expected: "test-value", - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "name", - expected: "", - expectError: false, - }, - { - name: "empty string parameter", - params: map[string]interface{}{"name": ""}, - paramName: "name", - expected: "", - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"name": 123}, - paramName: "name", - expected: "", - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalParam[string](tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func Test_RequiredInt(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected int - expectError bool - }{ - { - name: "valid number parameter", - params: map[string]interface{}{"count": float64(42)}, - paramName: "count", - expected: 42, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "count", - expected: 0, - expectError: true, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"count": "not-a-number"}, - paramName: "count", - expected: 0, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := RequiredInt(tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} -func Test_OptionalIntParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected int - expectError bool - }{ - { - name: "valid number parameter", - params: map[string]interface{}{"count": float64(42)}, - paramName: "count", - expected: 42, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "count", - expected: 0, - expectError: false, - }, - { - name: "zero value", - params: map[string]interface{}{"count": float64(0)}, - paramName: "count", - expected: 0, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"count": "not-a-number"}, - paramName: "count", - expected: 0, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalIntParam(tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func Test_OptionalNumberParamWithDefault(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - defaultVal int - expected int - expectError bool - }{ - { - name: "valid number parameter", - params: map[string]interface{}{"count": float64(42)}, - paramName: "count", - defaultVal: 10, - expected: 42, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "count", - defaultVal: 10, - expected: 10, - expectError: false, - }, - { - name: "zero value", - params: map[string]interface{}{"count": float64(0)}, - paramName: "count", - defaultVal: 10, - expected: 10, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"count": "not-a-number"}, - paramName: "count", - defaultVal: 10, - expected: 0, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalIntParamWithDefault(tc.params, tc.paramName, tc.defaultVal) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func Test_OptionalBooleanParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected bool - expectError bool - }{ - { - name: "true value", - params: map[string]interface{}{"flag": true}, - paramName: "flag", - expected: true, - expectError: false, - }, - { - name: "false value", - params: map[string]interface{}{"flag": false}, - paramName: "flag", - expected: false, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "flag", - expected: false, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"flag": "not-a-boolean"}, - paramName: "flag", - expected: false, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalParam[bool](tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func TestOptionalStringArrayParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected []string - expectError bool - }{ - { - name: "parameter not in request", - params: map[string]any{}, - paramName: "flag", - expected: []string{}, - expectError: false, - }, - { - name: "valid any array parameter", - params: map[string]any{ - "flag": []any{"v1", "v2"}, - }, - paramName: "flag", - expected: []string{"v1", "v2"}, - expectError: false, - }, - { - name: "valid string array parameter", - params: map[string]any{ - "flag": []string{"v1", "v2"}, - }, - paramName: "flag", - expected: []string{"v1", "v2"}, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]any{ - "flag": 1, - }, - paramName: "flag", - expected: []string{}, - expectError: true, - }, - { - name: "wrong slice type parameter", - params: map[string]any{ - "flag": []any{"foo", 2}, - }, - paramName: "flag", - expected: []string{}, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalStringArrayParam(tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func TestOptionalPaginationParams(t *testing.T) { - tests := []struct { - name string - params map[string]any - expected PaginationParams - expectError bool - }{ - { - name: "no pagination parameters, default values", - params: map[string]any{}, - expected: PaginationParams{ - Page: 1, - PerPage: 30, - }, - expectError: false, - }, - { - name: "page parameter, default perPage", - params: map[string]any{ - "page": float64(2), - }, - expected: PaginationParams{ - Page: 2, - PerPage: 30, - }, - expectError: false, - }, - { - name: "perPage parameter, default page", - params: map[string]any{ - "perPage": float64(50), - }, - expected: PaginationParams{ - Page: 1, - PerPage: 50, - }, - expectError: false, - }, - { - name: "page and perPage parameters", - params: map[string]any{ - "page": float64(2), - "perPage": float64(50), - }, - expected: PaginationParams{ - Page: 2, - PerPage: 50, - }, - expectError: false, - }, - { - name: "invalid page parameter", - params: map[string]any{ - "page": "not-a-number", - }, - expected: PaginationParams{}, - expectError: true, - }, - { - name: "invalid perPage parameter", - params: map[string]any{ - "perPage": "not-a-number", - }, - expected: PaginationParams{}, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalPaginationParams(tc.params) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} diff --git a/pkg/http/server.go b/pkg/http/server.go new file mode 100644 index 000000000..8c2524d49 --- /dev/null +++ b/pkg/http/server.go @@ -0,0 +1,166 @@ +//go:build ignore + +package http + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + + ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type HTTPHandlerConfig struct { + Version string +} + +type MCPHTTPHandler struct { + config HTTPHandlerConfig + serverDependencies serverDependencies +} + +// serverDependencies holds the dependencies required by the MCP server. +type serverDependencies struct { + getClient github.GetClientFn + getGQClient github.GetGQLClientFn + getRawClient raw.GetRawClientFn + t translations.TranslationHelperFunc +} + +func NewMCPHTTPHandler(ctx context.Context, router chi.Router) (*MCPHTTPHandler, error) { + return &MCPHTTPHandler{}, nil +} + +type reqOptions struct { + readonly bool + toolsetRoute bool +} + +func (h *MCPHTTPHandler) RegisterRoutes(ctx context.Context, r chi.Router) error { + spanNameFormatter := otelhttp.WithSpanNameFormatter( + func(operation string, r *http.Request) string { + return r.Method + " " + routePattern(r) + }, + ) + + mount := func(pattern string, opts reqOptions) { + r.Mount(pattern, otelhttp.NewHandler(h.perRequestHTTPHandler(opts), pattern, spanNameFormatter)) + } + + // Always mount the root route as before + mount("/", reqOptions{}) + + // Use chi params for toolset and readonly + mount("/x/{toolset}", reqOptions{toolsetRoute: true}) + mount("/x/{toolset}/readonly", reqOptions{toolsetRoute: true, readonly: true}) + mount("/readonly", reqOptions{readonly: true}) + + return nil +} + +func (h *MCPHTTPHandler) perRequestHTTPHandler(opts reqOptions) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + // Initialize GitHub errors context early in the request lifecycle + ctx := r.Context() + ctx = ghErrors.ContextWithGitHubErrors(ctx) + r = r.WithContext(ctx) + + // Set the toolset and readonly params in the context + if opts.readonly { + r = setReadOnly(r) + } + if opts.toolsetRoute { + r = setToolset(r) + } + + //nolint:contextcheck // ctx returned from newMCPServerForRequest contains injected deps via ContextWithDeps + ctx, svr, _, err := m.newMCPServerForRequest(r) + if err != nil { + message := fmt.Sprintf("failed to create MCP server, %s", err.Error()) + status := http.StatusInternalServerError + http.Error(w, message, status) + return + } + + // Then pass to the main handler logic + h.serveHTTP(w, r) + }) +} + +func (h *MCPHTTPHandler) newMCPServerForRequest(r *http.Request) (context.Context, *mcp.Server, *inventory.Inventory, error) { + ctx := r.Context() + + getClient := h.serverDependencies.getClient + getGQClient := h.serverDependencies.getGQClient + getRawClient := h.serverDependencies.getRawClient + t := h.serverDependencies.t + + deps := github.NewBaseDeps( + clients.rest, + clients.gql, + clients.raw, + clients.repoAccess, + cfg.Translator, + github.FeatureFlags{LockdownMode: cfg.LockdownMode}, + cfg.ContentWindowSize, + ) + + // Build the registry with all configuration applied + builder := inventory.NewBuilder(). + SetTools(github.AllTools(t)). + SetResources(github.AllResources(t)). + SetPrompts(github.AllPrompts(t)) + + inv := builder.Build() + + enabledToolsetIDs := inv.EnabledToolsetIDs() + enabledToolsets := make([]string, len(enabledToolsetIDs)) + for i, id := range enabledToolsetIDs { + enabledToolsets[i] = string(id) + } + + svrOpts := mcp.ServerOptions{} + + svr := github.NewServer(h.config.Version, &svrOpts) + + return ctx, svr, inv, nil +} + +// routePattern returns the route pattern for the given HTTP request. +// It retrieves the route pattern from the chi RouteContext, falling back to the request's URL path +// if the route context or pattern is unavailable. If the pattern is "/*", it is normalized to "/". +// Additionally, it replaces curly braces in the pattern with a colon prefix to ensure compatibility +// with Datadog, which does not accept curly braces in route patterns. +func routePattern(r *http.Request) string { + rc := chi.RouteContext(r.Context()) + if rc == nil || rc.RoutePattern() == "" { + return r.URL.Path + } + + routePattern := rc.RoutePattern() + + // Normalise /* to / as even without an explicit catch all route, + // mounting at / can result in /* as the route pattern for unmatched paths. + if routePattern == "/*" { + routePattern = "/" + } + + // Replace { and } with : because Datadog doesn't like them. + // Before: /x/{toolset} + // After: /x/:toolset + routePattern = strings.ReplaceAll(routePattern, "{", ":") + routePattern = strings.ReplaceAll(routePattern, "}", "") + + return routePattern +} From cf6bab9908b9e2c45e4ba6a0029421a832190602 Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Mon, 12 Jan 2026 17:43:48 +0100 Subject: [PATCH 2/2] WIP --- cmd/github-mcp-server/main.go | 15 ++++ internal/ghmcp/http.go | 55 +++++++++------ pkg/github/dependencies.go | 42 +++++++++++ pkg/github/server.go | 6 +- pkg/http/headers/headers.go | 26 +++++++ pkg/http/middleware/token.go | 127 ++++++++++++++++++++++++++++++++++ 6 files changed, 246 insertions(+), 25 deletions(-) create mode 100644 pkg/http/headers/headers.go create mode 100644 pkg/http/middleware/token.go diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index cfb68be4e..f43279f53 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -88,6 +88,20 @@ var ( return ghmcp.RunStdioServer(stdioServerConfig) }, } + + httpCmd = &cobra.Command{ + Use: "http", + Short: "Start HTTP server", + Long: `Start a server that communicates via HTTP using JSON-RPC messages.`, + RunE: func(_ *cobra.Command, _ []string) error { + config := ghmcp.HTTPServerConfig{ + Version: version, + Host: viper.GetString("host"), + } + + return ghmcp.RunHTTPServer(config) + }, + } ) func init() { @@ -126,6 +140,7 @@ func init() { // Add subcommands rootCmd.AddCommand(stdioCmd) + rootCmd.AddCommand(httpCmd) } func initConfig() { diff --git a/internal/ghmcp/http.go b/internal/ghmcp/http.go index a46a502b0..d2728b39e 100644 --- a/internal/ghmcp/http.go +++ b/internal/ghmcp/http.go @@ -8,12 +8,13 @@ import ( "net/http" "os" "os/signal" - "strings" "syscall" "time" "github.com/github/github-mcp-server/pkg/github" + httpMiddleware "github.com/github/github-mcp-server/pkg/http/middleware" "github.com/github/github-mcp-server/pkg/translations" + "github.com/modelcontextprotocol/go-sdk/mcp" ) type HTTPServerConfig struct { @@ -23,9 +24,6 @@ type HTTPServerConfig struct { // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) Host string - // GitHub Token to authenticate with the GitHub API - Token string - // EnabledToolsets is a list of toolsets to enable // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#tool-configuration EnabledToolsets []string @@ -90,26 +88,9 @@ func RunHTTPServer(cfg HTTPServerConfig) error { logger := slog.New(slogHandler) logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "dynamicToolsets", cfg.DynamicToolsets, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) - // Fetch token scopes for scope-based tool filtering (PAT tokens only) - // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. - // Fine-grained PATs and other token types don't support this, so we skip filtering. - var tokenScopes []string - if strings.HasPrefix(cfg.Token, "ghp_") { - fetchedScopes, err := github.FetchTokenScopesForHost(ctx, cfg.Token, cfg.Host) - if err != nil { - logger.Warn("failed to fetch token scopes, continuing without scope filtering", "error", err) - } else { - tokenScopes = fetchedScopes - logger.Info("token scopes fetched for filtering", "scopes", tokenScopes) - } - } else { - logger.Debug("skipping scope filtering for non-PAT token") - } - ghServer, err := github.NewMCPServer(github.MCPServerConfig{ Version: cfg.Version, Host: cfg.Host, - Token: cfg.Token, EnabledToolsets: cfg.EnabledToolsets, EnabledTools: cfg.EnabledTools, EnabledFeatures: cfg.EnabledFeatures, @@ -120,10 +101,40 @@ func RunHTTPServer(cfg HTTPServerConfig) error { LockdownMode: cfg.LockdownMode, Logger: logger, RepoAccessTTL: cfg.RepoAccessCacheTTL, - TokenScopes: tokenScopes, }) if err != nil { return fmt.Errorf("failed to create MCP server: %w", err) } + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return ghServer + }, &mcp.StreamableHTTPOptions{}) + + httpSvr := http.Server{ + Addr: ":8082", + Handler: httpMiddleware.ExtractUserToken()(handler), + } + + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + logger.Info("shutting down server") + if err := httpSvr.Shutdown(shutdownCtx); err != nil { + logger.Error("error during server shutdown", "error", err) + } + }() + + if cfg.ExportTranslations { + // Once server is initialized, all translations are loaded + dumpTranslations() + } + + logger.Info("HTTP server listening on :8082") + if err := httpSvr.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return fmt.Errorf("HTTP server error: %w", err) + } + + logger.Info("server stopped gracefully") + return nil } diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index b41bf0b87..c0e31d8f2 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -4,6 +4,7 @@ import ( "context" "errors" + "github.com/github/github-mcp-server/pkg/http/middleware" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" @@ -190,3 +191,44 @@ func NewToolFromHandler( st.AcceptedScopes = scopes.ExpandScopes(requiredScopes...) return st } + +type RequestDep struct { + // Pre-created clients + Client *gogithub.Client + GQLClient *githubv4.Client + RawClient *raw.Client + + // Static dependencies + RepoAccessCache *lockdown.RepoAccessCache + T translations.TranslationHelperFunc + Flags FeatureFlags + ContentWindowSize int +} + +// GetClient implements ToolDependencies. +func (d RequestDep) GetClient(ctx context.Context) (*gogithub.Client, error) { + token := middleware.Token(ctx) + return d.Client.WithAuthToken(token.Token), nil +} + +// GetGQLClient implements ToolDependencies. +func (d RequestDep) GetGQLClient(_ context.Context) (*githubv4.Client, error) { + return d.GQLClient, nil +} + +// GetRawClient implements ToolDependencies. +func (d RequestDep) GetRawClient(_ context.Context) (*raw.Client, error) { + return d.RawClient, nil +} + +// GetRepoAccessCache implements ToolDependencies. +func (d RequestDep) GetRepoAccessCache() *lockdown.RepoAccessCache { return d.RepoAccessCache } + +// GetT implements ToolDependencies. +func (d RequestDep) GetT() translations.TranslationHelperFunc { return d.T } + +// GetFlags implements ToolDependencies. +func (d RequestDep) GetFlags() FeatureFlags { return d.Flags } + +// GetContentWindowSize implements ToolDependencies. +func (d RequestDep) GetContentWindowSize() int { return d.ContentWindowSize } diff --git a/pkg/github/server.go b/pkg/github/server.go index acc2e9f5b..3b2993440 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -124,9 +124,9 @@ func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, }, nil } -// ResolveEnabledToolsets determines which toolsets should be enabled based on config. +// resolveEnabledToolsets determines which toolsets should be enabled based on config. // Returns nil for "use defaults", empty slice for "none", or explicit list. -func ResolveEnabledToolsets(cfg MCPServerConfig) []string { +func resolveEnabledToolsets(cfg MCPServerConfig) []string { enabledToolsets := cfg.EnabledToolsets // In dynamic mode, remove "all" and "default" since users enable toolsets on demand @@ -162,7 +162,7 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { return nil, fmt.Errorf("failed to create GitHub clients: %w", err) } - enabledToolsets := ResolveEnabledToolsets(cfg) + enabledToolsets := resolveEnabledToolsets(cfg) // For instruction generation, we need actual toolset names (not nil). // nil means "use defaults" in inventory, so expand it for instructions. diff --git a/pkg/http/headers/headers.go b/pkg/http/headers/headers.go new file mode 100644 index 000000000..7b47eeaf6 --- /dev/null +++ b/pkg/http/headers/headers.go @@ -0,0 +1,26 @@ +package headers + +const ( + // AuthorizationHeader is a standard HTTP Header. + AuthorizationHeader = "Authorization" + // ContentTypeHeader is a standard HTTP Header. + ContentTypeHeader = "Content-Type" + // AcceptHeader is a standard HTTP Header. + AcceptHeader = "Accept" + // UserAgentHeader is a standard HTTP Header. + UserAgentHeader = "User-Agent" + + // ContentTypeJSON is the standard MIME type for JSON. + ContentTypeJSON = "application/json" + // ContentTypeEventStream is the standard MIME type for Event Streams. + ContentTypeEventStream = "text/event-stream" + + // ForwardedForHeader is a standard HTTP Header used to forward the originating IP address of a client. + ForwardedForHeader = "X-Forwarded-For" + + // RealIPHeader is a standard HTTP Header used to indicate the real IP address of the client. + RealIPHeader = "X-Real-IP" + + // RequestHmacHeader is used to authenticate requests to the Raw API. + RequestHmacHeader = "Request-Hmac" +) diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go new file mode 100644 index 000000000..8208abaf1 --- /dev/null +++ b/pkg/http/middleware/token.go @@ -0,0 +1,127 @@ +package middleware + +import ( + "context" + "errors" + "net/http" + "regexp" + "strings" + + httpheaders "github.com/github/github-mcp-server/pkg/http/headers" +) + +type authType int + +const ( + authTypeUnknown authType = iota + authTypeIDE + authTypeGhToken +) + +var ( + errMissingAuthorizationHeader = errors.New("missing Authorization header") + errBadAuthorizationHeader = errors.New("bad Authorization header format") + errUnsupportedAuthorizationHeader = errors.New("unsupported Authorization header format") +) + +var supportedThirdPartyTokenPrefixes = []string{ + "ghp_", // Personal access token (classic) + "github_pat_", // Fine-grained personal access token + "gho_", // OAuth access token + "ghu_", // User access token for a GitHub App + "ghs_", // Installation access token for a GitHub App (a.k.a. server-to-server token) +} + +// oldPatternRegexp is the regular expression for the old pattern of the token. +// Until 2021, GitHub API tokens did not have an identifiable prefix. They +// were 40 characters long and only contained the characters a-f and 0-9. +var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`) + +type tokenCtxKey string + +var tokenContextKey tokenCtxKey = "tokenctx" + +type TokenData struct { + Token string +} + +// AddToken adds the given token data to the context. +func AddToken(ctx context.Context, data *TokenData) context.Context { + return context.WithValue(ctx, tokenContextKey, data) +} + +// ReqData returns the request data from the context. It will panic if there is +// no data in the context (which should never happen in production). +func Token(ctx context.Context) *TokenData { + d, ok := ctx.Value(tokenContextKey).(*TokenData) + if !ok || d == nil { + // This should never happen in production, so making it a panic saves us a lot of unnecessary error handling. + panic(errors.New("context does not contain request context token data")) + } + return d +} + +// ExtractUserToken is a middleware that extracts the user token from the request +// and adds it to the request context. It also validates the token format. +func ExtractUserToken() func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, token, err := parseAuthorizationHeader(r) + if err != nil { + // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec + if errors.Is(err, errMissingAuthorizationHeader) { + // sendAuthChallenge(w, r, cfg, obsv) + return + } + // For other auth errors (bad format, unsupported), return 400 + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Add token info to context + ctx := r.Context() + ctx = AddToken(ctx, &TokenData{Token: token}) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func parseAuthorizationHeader(req *http.Request) (authType authType, token string, _ error) { + authHeader := req.Header.Get(httpheaders.AuthorizationHeader) + if authHeader == "" { + return 0, "", errMissingAuthorizationHeader + } + + switch { + // decrypt dotcom token and set it as token + case strings.HasPrefix(authHeader, "GitHub-Bearer "): + return 0, "", errUnsupportedAuthorizationHeader + default: + // support both "Bearer" and "bearer" to conform to api.github.com + if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") { + token = authHeader[7:] + } else { + token = authHeader + } + } + + // Do a naïve check for a colon in the token - currently, only the IDE token has a colon in it. + // ex: tid=1;exp=25145314523;chat=1: + if strings.Contains(token, ":") { + return authTypeIDE, token, nil + } + + for _, prefix := range supportedThirdPartyTokenPrefixes { + if strings.HasPrefix(token, prefix) { + return authTypeGhToken, token, nil + } + } + + matchesOldTokenPattern := oldPatternRegexp.MatchString(token) + if matchesOldTokenPattern { + return authTypeGhToken, token, nil + } + + return authTypeUnknown, "", errBadAuthorizationHeader +}