diff --git a/.vscode/launch.json b/.vscode/launch.json index cea7fd917..c4d4015fa 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -23,6 +23,16 @@ "program": "cmd/github-mcp-server/main.go", "args": ["stdio", "--read-only"], "console": "integratedTerminal", + }, + { + "name": "Launch http server", + "type": "go", + "request": "launch", + "mode": "auto", + "cwd": "${workspaceFolder}", + "program": "cmd/github-mcp-server/main.go", + "args": ["http"], + "console": "integratedTerminal", } ] } \ No newline at end of file diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index cfb68be4e..4350d0137 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -9,6 +9,7 @@ import ( "github.com/github/github-mcp-server/internal/ghmcp" "github.com/github/github-mcp-server/pkg/github" + ghhttp "github.com/github/github-mcp-server/pkg/http" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -88,6 +89,24 @@ var ( return ghmcp.RunStdioServer(stdioServerConfig) }, } + + httpCmd = &cobra.Command{ + Use: "http", + Short: "Start HTTP server", + Long: `Start an HTTP server that listens for MCP requests over HTTP.`, + RunE: func(_ *cobra.Command, _ []string) error { + httpConfig := ghhttp.HTTPServerConfig{ + Version: version, + Host: viper.GetString("host"), + ExportTranslations: viper.GetBool("export-translations"), + EnableCommandLogging: viper.GetBool("enable-command-logging"), + LogFilePath: viper.GetString("log-file"), + ContentWindowSize: viper.GetInt("content-window-size"), + } + + return ghhttp.RunHTTPServer(httpConfig) + }, + } ) func init() { @@ -126,6 +145,7 @@ func init() { // Add subcommands rootCmd.AddCommand(stdioCmd) + rootCmd.AddCommand(httpCmd) } func initConfig() { diff --git a/go.mod b/go.mod index 5322b47ec..f42130786 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-chi/chi/v5 v5.2.3 github.com/go-viper/mapstructure/v2 v2.4.0 github.com/google/go-querystring v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 25cbf7fa9..0a43d3f09 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= +github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 250f6b4cc..e73d6105c 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -6,7 +6,6 @@ import ( "io" "log/slog" "net/http" - "net/url" "os" "os/signal" "strings" @@ -15,66 +14,18 @@ import ( "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/github" - "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/http/transport" "github.com/github/github-mcp-server/pkg/lockdown" mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/scopes" "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/shurcooL/githubv4" ) -type MCPServerConfig struct { - // Version of the server - Version string - - // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) - Host string - - // GitHub Token to authenticate with the GitHub API - Token string - - // EnabledToolsets is a list of toolsets to enable - // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#tool-configuration - EnabledToolsets []string - - // EnabledTools is a list of specific tools to enable (additive to toolsets) - // When specified, these tools are registered in addition to any specified toolset tools - EnabledTools []string - - // EnabledFeatures is a list of feature flags that are enabled - // Items with FeatureFlagEnable matching an entry in this list will be available - EnabledFeatures []string - - // Whether to enable dynamic toolsets - // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery - DynamicToolsets bool - - // ReadOnly indicates if we should only offer read-only tools - ReadOnly bool - - // Translator provides translated text for the server tooling - Translator translations.TranslationHelperFunc - - // Content window size - ContentWindowSize int - - // LockdownMode indicates if we should enable lockdown mode - LockdownMode bool - - // Logger is used for logging within the server - Logger *slog.Logger - // RepoAccessTTL overrides the default TTL for repository access cache entries. - RepoAccessTTL *time.Duration - - // TokenScopes contains the OAuth scopes available to the token. - // When non-nil, tools requiring scopes not in this list will be hidden. - // This is used for PAT scope filtering where we can't issue scope challenges. - TokenScopes []string -} - // githubClients holds all the GitHub API clients created for a server instance. type githubClients struct { rest *gogithub.Client @@ -85,25 +36,25 @@ type githubClients struct { } // createGitHubClients creates all the GitHub API clients needed by the server. -func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, error) { +func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.ApiHost) (*githubClients, error) { // Construct REST client restClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token) restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version) - restClient.BaseURL = apiHost.baseRESTURL - restClient.UploadURL = apiHost.uploadURL + restClient.BaseURL = apiHost.BaseRESTURL + restClient.UploadURL = apiHost.UploadURL // Construct GraphQL client // We use NewEnterpriseClient unconditionally since we already parsed the API host gqlHTTPClient := &http.Client{ - Transport: &bearerAuthTransport{ - transport: http.DefaultTransport, - token: cfg.Token, + Transport: &transport.BearerAuthTransport{ + Transport: http.DefaultTransport, + Token: cfg.Token, }, } - gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) + gqlClient := githubv4.NewEnterpriseClient(apiHost.GraphqlURL.String(), gqlHTTPClient) // Create raw content client (shares REST client's HTTP transport) - rawClient := raw.NewClient(restClient, apiHost.rawURL) + rawClient := raw.NewClient(restClient, apiHost.RawURL) // Set up repo access cache for lockdown mode var repoAccessCache *lockdown.RepoAccessCache @@ -126,35 +77,8 @@ func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, }, nil } -// resolveEnabledToolsets determines which toolsets should be enabled based on config. -// Returns nil for "use defaults", empty slice for "none", or explicit list. -func resolveEnabledToolsets(cfg MCPServerConfig) []string { - enabledToolsets := cfg.EnabledToolsets - - // In dynamic mode, remove "all" and "default" since users enable toolsets on demand - if cfg.DynamicToolsets && enabledToolsets != nil { - enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataAll.ID)) - enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataDefault.ID)) - } - - if enabledToolsets != nil { - return enabledToolsets - } - if cfg.DynamicToolsets { - // Dynamic mode with no toolsets specified: start empty so users enable on demand - return []string{} - } - if len(cfg.EnabledTools) > 0 { - // When specific tools are requested but no toolsets, don't use default toolsets - // This matches the original behavior: --tools=X alone registers only X - return []string{} - } - // nil means "use defaults" in WithToolsets - return nil -} - -func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { - apiHost, err := parseAPIHost(cfg.Host) +func NewStdioMCPServer(cfg github.MCPServerConfig) (*mcp.Server, error) { + apiHost, err := utils.ParseAPIHost(cfg.Host) if err != nil { return nil, fmt.Errorf("failed to parse API host: %w", err) } @@ -164,40 +88,6 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { return nil, fmt.Errorf("failed to create GitHub clients: %w", err) } - enabledToolsets := resolveEnabledToolsets(cfg) - - // For instruction generation, we need actual toolset names (not nil). - // nil means "use defaults" in inventory, so expand it for instructions. - instructionToolsets := enabledToolsets - if instructionToolsets == nil { - instructionToolsets = github.GetDefaultToolsetIDs() - } - - // Create the MCP server - serverOpts := &mcp.ServerOptions{ - Instructions: github.GenerateInstructions(instructionToolsets), - Logger: cfg.Logger, - CompletionHandler: github.CompletionsHandler(func(_ context.Context) (*gogithub.Client, error) { - return clients.rest, nil - }), - } - - // In dynamic mode, explicitly advertise capabilities since tools/resources/prompts - // may be enabled at runtime even if none are registered initially. - if cfg.DynamicToolsets { - serverOpts.Capabilities = &mcp.ServerCapabilities{ - Tools: &mcp.ToolCapabilities{}, - Resources: &mcp.ResourceCapabilities{}, - Prompts: &mcp.PromptCapabilities{}, - } - } - - ghServer := github.NewServer(cfg.Version, serverOpts) - - // Add middlewares - ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) - ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, clients.rest, clients.gqlHTTP)) - // Create dependencies for tool handlers deps := github.NewBaseDeps( clients.rest, @@ -208,21 +98,13 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { github.FeatureFlags{LockdownMode: cfg.LockdownMode}, cfg.ContentWindowSize, ) - - // Inject dependencies into context for all tool handlers - ghServer.AddReceivingMiddleware(func(next mcp.MethodHandler) mcp.MethodHandler { - return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - return next(github.ContextWithDeps(ctx, deps), method, req) - } - }) - // Build and register the tool/resource/prompt inventory inventoryBuilder := github.NewInventory(cfg.Translator). WithDeprecatedAliases(github.DeprecatedToolAliases). WithReadOnly(cfg.ReadOnly). - WithToolsets(enabledToolsets). - WithTools(github.CleanTools(cfg.EnabledTools)). - WithFeatureChecker(createFeatureChecker(cfg.EnabledFeatures)) + WithToolsets(cfg.EnabledToolsets). + WithTools(github.CleanTools(cfg.EnabledTools)) + // WithFeatureChecker(createFeatureChecker(cfg.EnabledFeatures)) // Apply token scope filtering if scopes are known (for PAT filtering) if cfg.TokenScopes != nil { @@ -231,52 +113,16 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { inventory := inventoryBuilder.Build() - if unrecognized := inventory.UnrecognizedToolsets(); len(unrecognized) > 0 { - fmt.Fprintf(os.Stderr, "Warning: unrecognized toolsets ignored: %s\n", strings.Join(unrecognized, ", ")) + ghServer, err := github.NewMCPServer(&cfg, deps, inventory) + if err != nil { + return nil, fmt.Errorf("failed to create GitHub MCP server: %w", err) } - // Register GitHub tools/resources/prompts from the inventory. - // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets - // is empty - users enable toolsets at runtime via the dynamic tools below (but can - // enable toolsets or tools explicitly that do need registration). - inventory.RegisterAll(context.Background(), ghServer, deps) - - // Register dynamic toolset management tools (enable/disable) - these are separate - // meta-tools that control the inventory, not part of the inventory itself - if cfg.DynamicToolsets { - registerDynamicTools(ghServer, inventory, deps, cfg.Translator) - } + ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, clients.rest, clients.gqlHTTP)) return ghServer, nil } -// registerDynamicTools adds the dynamic toolset enable/disable tools to the server. -func registerDynamicTools(server *mcp.Server, inventory *inventory.Inventory, deps *github.BaseDeps, t translations.TranslationHelperFunc) { - dynamicDeps := github.DynamicToolDependencies{ - Server: server, - Inventory: inventory, - ToolDeps: deps, - T: t, - } - for _, tool := range github.DynamicTools(inventory) { - tool.RegisterFunc(server, dynamicDeps) - } -} - -// createFeatureChecker returns a FeatureFlagChecker that checks if a flag name -// is present in the provided list of enabled features. For the local server, -// this is populated from the --features CLI flag. -func createFeatureChecker(enabledFeatures []string) inventory.FeatureFlagChecker { - // Build a set for O(1) lookup - featureSet := make(map[string]bool, len(enabledFeatures)) - for _, f := range enabledFeatures { - featureSet[f] = true - } - return func(_ context.Context, flagName string) (bool, error) { - return featureSet[flagName], nil - } -} - type StdioServerConfig struct { // Version of the server Version string @@ -366,7 +212,7 @@ func RunStdioServer(cfg StdioServerConfig) error { logger.Debug("skipping scope filtering for non-PAT token") } - ghServer, err := NewMCPServer(MCPServerConfig{ + ghServer, err := NewStdioMCPServer(github.MCPServerConfig{ Version: cfg.Version, Host: cfg.Host, Token: cfg.Token, @@ -427,220 +273,7 @@ func RunStdioServer(cfg StdioServerConfig) error { return nil } -type apiHost struct { - baseRESTURL *url.URL - graphqlURL *url.URL - uploadURL *url.URL - rawURL *url.URL -} - -func newDotcomHost() (apiHost, error) { - baseRestURL, err := url.Parse("https://api.github.com/") - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse dotcom REST URL: %w", err) - } - - gqlURL, err := url.Parse("https://api.github.com/graphql") - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse dotcom GraphQL URL: %w", err) - } - - uploadURL, err := url.Parse("https://uploads.github.com") - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse dotcom Upload URL: %w", err) - } - - rawURL, err := url.Parse("https://raw.githubusercontent.com/") - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err) - } - - return apiHost{ - baseRESTURL: baseRestURL, - graphqlURL: gqlURL, - uploadURL: uploadURL, - rawURL: rawURL, - }, nil -} - -func newGHECHost(hostname string) (apiHost, error) { - u, err := url.Parse(hostname) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC URL: %w", err) - } - - // Unsecured GHEC would be an error - if u.Scheme == "http" { - return apiHost{}, fmt.Errorf("GHEC URL must be HTTPS") - } - - restURL, err := url.Parse(fmt.Sprintf("https://api.%s/", u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC REST URL: %w", err) - } - - gqlURL, err := url.Parse(fmt.Sprintf("https://api.%s/graphql", u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC GraphQL URL: %w", err) - } - - uploadURL, err := url.Parse(fmt.Sprintf("https://uploads.%s", u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC Upload URL: %w", err) - } - - rawURL, err := url.Parse(fmt.Sprintf("https://raw.%s/", u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err) - } - - return apiHost{ - baseRESTURL: restURL, - graphqlURL: gqlURL, - uploadURL: uploadURL, - rawURL: rawURL, - }, nil -} - -func newGHESHost(hostname string) (apiHost, error) { - u, err := url.Parse(hostname) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES URL: %w", err) - } - - restURL, err := url.Parse(fmt.Sprintf("%s://%s/api/v3/", u.Scheme, u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES REST URL: %w", err) - } - - gqlURL, err := url.Parse(fmt.Sprintf("%s://%s/api/graphql", u.Scheme, u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES GraphQL URL: %w", err) - } - - // Check if subdomain isolation is enabled - // See https://docs.github.com/en/enterprise-server@3.17/admin/configuring-settings/hardening-security-for-your-enterprise/enabling-subdomain-isolation#about-subdomain-isolation - hasSubdomainIsolation := checkSubdomainIsolation(u.Scheme, u.Hostname()) - - var uploadURL *url.URL - if hasSubdomainIsolation { - // With subdomain isolation: https://uploads.hostname/ - uploadURL, err = url.Parse(fmt.Sprintf("%s://uploads.%s/", u.Scheme, u.Hostname())) - } else { - // Without subdomain isolation: https://hostname/api/uploads/ - uploadURL, err = url.Parse(fmt.Sprintf("%s://%s/api/uploads/", u.Scheme, u.Hostname())) - } - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES Upload URL: %w", err) - } - - var rawURL *url.URL - if hasSubdomainIsolation { - // With subdomain isolation: https://raw.hostname/ - rawURL, err = url.Parse(fmt.Sprintf("%s://raw.%s/", u.Scheme, u.Hostname())) - } else { - // Without subdomain isolation: https://hostname/raw/ - rawURL, err = url.Parse(fmt.Sprintf("%s://%s/raw/", u.Scheme, u.Hostname())) - } - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err) - } - - return apiHost{ - baseRESTURL: restURL, - graphqlURL: gqlURL, - uploadURL: uploadURL, - rawURL: rawURL, - }, nil -} - -// checkSubdomainIsolation detects if GitHub Enterprise Server has subdomain isolation enabled -// by attempting to ping the raw./_ping endpoint on the subdomain. The raw subdomain must always exist for subdomain isolation. -func checkSubdomainIsolation(scheme, hostname string) bool { - subdomainURL := fmt.Sprintf("%s://raw.%s/_ping", scheme, hostname) - - client := &http.Client{ - Timeout: 5 * time.Second, - // Don't follow redirects - we just want to check if the endpoint exists - //nolint:revive // parameters are required by http.Client.CheckRedirect signature - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - resp, err := client.Get(subdomainURL) - if err != nil { - return false - } - defer resp.Body.Close() - - return resp.StatusCode == http.StatusOK -} - -// Note that this does not handle ports yet, so development environments are out. -func parseAPIHost(s string) (apiHost, error) { - if s == "" { - return newDotcomHost() - } - - u, err := url.Parse(s) - if err != nil { - return apiHost{}, fmt.Errorf("could not parse host as URL: %s", s) - } - - if u.Scheme == "" { - return apiHost{}, fmt.Errorf("host must have a scheme (http or https): %s", s) - } - - if strings.HasSuffix(u.Hostname(), "github.com") { - return newDotcomHost() - } - - if strings.HasSuffix(u.Hostname(), "ghe.com") { - return newGHECHost(s) - } - - return newGHESHost(s) -} - -type userAgentTransport struct { - transport http.RoundTripper - agent string -} - -func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - req.Header.Set("User-Agent", t.agent) - return t.transport.RoundTrip(req) -} - -type bearerAuthTransport struct { - transport http.RoundTripper - token string -} - -func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - req.Header.Set("Authorization", "Bearer "+t.token) - - // Check for GraphQL-Features in context and add header if present - if features := github.GetGraphQLFeatures(req.Context()); len(features) > 0 { - req.Header.Set("GraphQL-Features", strings.Join(features, ", ")) - } - - return t.transport.RoundTrip(req) -} - -func addGitHubAPIErrorToContext(next mcp.MethodHandler) mcp.MethodHandler { - return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) { - // Ensure the context is cleared of any previous errors - // as context isn't propagated through middleware - ctx = errors.ContextWithGitHubErrors(ctx) - return next(ctx, method, req) - } -} - -func addUserAgentsMiddleware(cfg MCPServerConfig, restClient *gogithub.Client, gqlHTTPClient *http.Client) func(next mcp.MethodHandler) mcp.MethodHandler { +func addUserAgentsMiddleware(cfg github.MCPServerConfig, restClient *gogithub.Client, gqlHTTPClient *http.Client) func(next mcp.MethodHandler) mcp.MethodHandler { return func(next mcp.MethodHandler) mcp.MethodHandler { return func(ctx context.Context, method string, request mcp.Request) (result mcp.Result, err error) { if method != "initialize" { @@ -662,9 +295,9 @@ func addUserAgentsMiddleware(cfg MCPServerConfig, restClient *gogithub.Client, g restClient.UserAgent = userAgent - gqlHTTPClient.Transport = &userAgentTransport{ - transport: gqlHTTPClient.Transport, - agent: userAgent, + gqlHTTPClient.Transport = &transport.UserAgentTransport{ + Transport: gqlHTTPClient.Transport, + Agent: userAgent, } return next(ctx, method, request) @@ -675,13 +308,13 @@ func addUserAgentsMiddleware(cfg MCPServerConfig, restClient *gogithub.Client, g // fetchTokenScopesForHost fetches the OAuth scopes for a token from the GitHub API. // It constructs the appropriate API host URL based on the configured host. func fetchTokenScopesForHost(ctx context.Context, token, host string) ([]string, error) { - apiHost, err := parseAPIHost(host) + apiHost, err := utils.ParseAPIHost(host) if err != nil { return nil, fmt.Errorf("failed to parse API host: %w", err) } fetcher := scopes.NewFetcher(scopes.FetcherOptions{ - APIHost: apiHost.baseRESTURL.String(), + APIHost: apiHost.BaseRESTURL.String(), }) return fetcher.FetchTokenScopes(ctx, token) diff --git a/internal/ghmcp/server_test.go b/internal/ghmcp/server_test.go index 04c0989d4..6f0e3ac3f 100644 --- a/internal/ghmcp/server_test.go +++ b/internal/ghmcp/server_test.go @@ -1,112 +1 @@ package ghmcp - -import ( - "testing" - - "github.com/github/github-mcp-server/pkg/translations" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestNewMCPServer_CreatesSuccessfully verifies that the server can be created -// with the deps injection middleware properly configured. -func TestNewMCPServer_CreatesSuccessfully(t *testing.T) { - t.Parallel() - - // Create a minimal server configuration - cfg := MCPServerConfig{ - Version: "test", - Host: "", // defaults to github.com - Token: "test-token", - EnabledToolsets: []string{"context"}, - ReadOnly: false, - Translator: translations.NullTranslationHelper, - ContentWindowSize: 5000, - LockdownMode: false, - } - - // Create the server - server, err := NewMCPServer(cfg) - require.NoError(t, err, "expected server creation to succeed") - require.NotNil(t, server, "expected server to be non-nil") - - // The fact that the server was created successfully indicates that: - // 1. The deps injection middleware is properly added - // 2. Tools can be registered without panicking - // - // If the middleware wasn't properly added, tool calls would panic with - // "ToolDependencies not found in context" when executed. - // - // The actual middleware functionality and tool execution with ContextWithDeps - // is already tested in pkg/github/*_test.go. -} - -// TestResolveEnabledToolsets verifies the toolset resolution logic. -func TestResolveEnabledToolsets(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - cfg MCPServerConfig - expectedResult []string - }{ - { - name: "nil toolsets without dynamic mode and no tools - use defaults", - cfg: MCPServerConfig{ - EnabledToolsets: nil, - DynamicToolsets: false, - EnabledTools: nil, - }, - expectedResult: nil, // nil means "use defaults" - }, - { - name: "nil toolsets with dynamic mode - start empty", - cfg: MCPServerConfig{ - EnabledToolsets: nil, - DynamicToolsets: true, - EnabledTools: nil, - }, - expectedResult: []string{}, // empty slice means no toolsets - }, - { - name: "explicit toolsets", - cfg: MCPServerConfig{ - EnabledToolsets: []string{"repos", "issues"}, - DynamicToolsets: false, - }, - expectedResult: []string{"repos", "issues"}, - }, - { - name: "empty toolsets - disable all", - cfg: MCPServerConfig{ - EnabledToolsets: []string{}, - DynamicToolsets: false, - }, - expectedResult: []string{}, // empty slice means no toolsets - }, - { - name: "specific tools without toolsets - no default toolsets", - cfg: MCPServerConfig{ - EnabledToolsets: nil, - DynamicToolsets: false, - EnabledTools: []string{"get_me"}, - }, - expectedResult: []string{}, // empty slice when tools specified but no toolsets - }, - { - name: "dynamic mode with explicit toolsets removes all and default", - cfg: MCPServerConfig{ - EnabledToolsets: []string{"all", "repos"}, - DynamicToolsets: true, - }, - expectedResult: []string{"repos"}, // "all" is removed in dynamic mode - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := resolveEnabledToolsets(tc.cfg) - assert.Equal(t, tc.expectedResult, result) - }) - } -} diff --git a/pkg/context/graphql_features.go b/pkg/context/graphql_features.go new file mode 100644 index 000000000..ebba3f757 --- /dev/null +++ b/pkg/context/graphql_features.go @@ -0,0 +1,19 @@ +package context + +import "context" + +// graphQLFeaturesKey is a context key for GraphQL feature flags +type graphQLFeaturesKey struct{} + +// withGraphQLFeatures adds GraphQL feature flags to the context +func WithGraphQLFeatures(ctx context.Context, features ...string) context.Context { + return context.WithValue(ctx, graphQLFeaturesKey{}, features) +} + +// GetGraphQLFeatures retrieves GraphQL feature flags from the context +func GetGraphQLFeatures(ctx context.Context) []string { + if features, ok := ctx.Value(graphQLFeaturesKey{}).([]string); ok { + return features + } + return nil +} diff --git a/pkg/context/token.go b/pkg/context/token.go new file mode 100644 index 000000000..dd303f029 --- /dev/null +++ b/pkg/context/token.go @@ -0,0 +1,19 @@ +package context + +import "context" + +// tokenCtxKey is a context key for authentication token information +type tokenCtxKey struct{} + +// WithTokenInfo adds TokenInfo to the context +func WithTokenInfo(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, tokenCtxKey{}, token) +} + +// GetTokenInfo retrieves the authentication token from the context +func GetTokenInfo(ctx context.Context) (string, bool) { + if token, ok := ctx.Value(tokenCtxKey{}).(string); ok { + return token, true + } + return "", false +} diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index b41bf0b87..e5403d9b1 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -3,12 +3,17 @@ package github import ( "context" "errors" + "fmt" + "net/http" + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/transport" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/scopes" "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/shurcooL/githubv4" @@ -21,6 +26,14 @@ type depsContextKey struct{} // ErrDepsNotInContext is returned when ToolDependencies is not found in context. var ErrDepsNotInContext = errors.New("ToolDependencies not found in context; use ContextWithDeps to inject") +func InjectDepsMiddleware(deps ToolDependencies) mcp.Middleware { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) { + return next(ContextWithDeps(ctx, deps), method, req) + } + } +} + // ContextWithDeps returns a new context with the ToolDependencies stored in it. // This is used to inject dependencies at request time rather than at registration time, // avoiding expensive closure creation during server initialization. @@ -67,7 +80,7 @@ type ToolDependencies interface { GetRawClient(ctx context.Context) (*raw.Client, error) // GetRepoAccessCache returns the lockdown mode repo access cache - GetRepoAccessCache() *lockdown.RepoAccessCache + GetRepoAccessCache(ctx context.Context) (*lockdown.RepoAccessCache, error) // GetT returns the translation helper function GetT() translations.TranslationHelperFunc @@ -132,7 +145,9 @@ func (d BaseDeps) GetRawClient(_ context.Context) (*raw.Client, error) { } // GetRepoAccessCache implements ToolDependencies. -func (d BaseDeps) GetRepoAccessCache() *lockdown.RepoAccessCache { return d.RepoAccessCache } +func (d BaseDeps) GetRepoAccessCache(_ context.Context) (*lockdown.RepoAccessCache, error) { + return d.RepoAccessCache, nil +} // GetT implements ToolDependencies. func (d BaseDeps) GetT() translations.TranslationHelperFunc { return d.T } @@ -190,3 +205,126 @@ func NewToolFromHandler( st.AcceptedScopes = scopes.ExpandScopes(requiredScopes...) return st } + +type RequestDeps struct { + Client *gogithub.Client + GQLClient *githubv4.Client + RawClient *raw.Client + RepoAccessCache *lockdown.RepoAccessCache + + // Static dependencies + apiHosts *utils.ApiHost + version string + lockdownMode bool + RepoAccessOpts []lockdown.RepoAccessOption + T translations.TranslationHelperFunc + Flags FeatureFlags + ContentWindowSize int +} + +// NewRequestDeps creates a RequestDeps with the provided clients and configuration. +func NewRequestDeps( + apiHosts *utils.ApiHost, + version string, + lockdownMode bool, + repoAccessOpts []lockdown.RepoAccessOption, + t translations.TranslationHelperFunc, + flags FeatureFlags, + contentWindowSize int, +) *RequestDeps { + return &RequestDeps{ + apiHosts: apiHosts, + version: version, + lockdownMode: lockdownMode, + RepoAccessOpts: repoAccessOpts, + T: t, + Flags: flags, + ContentWindowSize: contentWindowSize, + } +} + +// GetClient implements ToolDependencies. +func (d *RequestDeps) GetClient(ctx context.Context) (*gogithub.Client, error) { + if d.Client != nil { + return d.Client, nil + } + + // extract the token from the context + token, _ := ghcontext.GetTokenInfo(ctx) + + // Construct REST client + restClient := gogithub.NewClient(nil).WithAuthToken(token) + restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", d.version) + restClient.BaseURL = d.apiHosts.BaseRESTURL + restClient.UploadURL = d.apiHosts.UploadURL + return restClient, nil +} + +// GetGQLClient implements ToolDependencies. +func (d *RequestDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error) { + if d.GQLClient != nil { + return d.GQLClient, nil + } + + // extract the token from the context + token, _ := ghcontext.GetTokenInfo(ctx) + + // Construct GraphQL client + // We use NewEnterpriseClient unconditionally since we already parsed the API host + gqlHTTPClient := &http.Client{ + Transport: &transport.BearerAuthTransport{ + Transport: http.DefaultTransport, + Token: token, + }, + } + gqlClient := githubv4.NewEnterpriseClient(d.apiHosts.GraphqlURL.String(), gqlHTTPClient) + d.GQLClient = gqlClient + return gqlClient, nil +} + +// GetRawClient implements ToolDependencies. +func (d *RequestDeps) GetRawClient(ctx context.Context) (*raw.Client, error) { + if d.RawClient != nil { + return d.RawClient, nil + } + + client, err := d.GetClient(ctx) + if err != nil { + return nil, err + } + + rawClient := raw.NewClient(client, d.apiHosts.RawURL) + d.RawClient = rawClient + + return rawClient, nil +} + +// GetRepoAccessCache implements ToolDependencies. +func (d *RequestDeps) GetRepoAccessCache(ctx context.Context) (*lockdown.RepoAccessCache, error) { + if !d.lockdownMode { + return nil, nil + } + + if d.RepoAccessCache != nil { + return d.RepoAccessCache, nil + } + + gqlClient, err := d.GetGQLClient(ctx) + if err != nil { + return nil, err + } + + // Create repo access cache + instance := lockdown.GetInstance(gqlClient, d.RepoAccessOpts...) + d.RepoAccessCache = instance + return instance, nil +} + +// GetT implements ToolDependencies. +func (d *RequestDeps) GetT() translations.TranslationHelperFunc { return d.T } + +// GetFlags implements ToolDependencies. +func (d *RequestDeps) GetFlags() FeatureFlags { return d.Flags } + +// GetContentWindowSize implements ToolDependencies. +func (d *RequestDeps) GetContentWindowSize() int { return d.ContentWindowSize } diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 63174c9e9..7e91d2076 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -9,9 +9,9 @@ import ( "strings" "time" + ghcontext "github.com/github/github-mcp-server/pkg/context" ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/inventory" - "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/octicons" "github.com/github/github-mcp-server/pkg/sanitize" "github.com/github/github-mcp-server/pkg/scopes" @@ -312,13 +312,13 @@ Options are: switch method { case "get": - result, err := GetIssue(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, deps.GetFlags()) + result, err := GetIssue(ctx, client, deps, owner, repo, issueNumber) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) + result, err := GetIssueComments(ctx, client, deps, owner, repo, issueNumber, pagination) return result, nil, err case "get_sub_issues": - result, err := GetSubIssues(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) + result, err := GetSubIssues(ctx, client, deps, owner, repo, issueNumber, pagination) return result, nil, err case "get_labels": result, err := GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) @@ -329,7 +329,13 @@ Options are: }) } -func GetIssue(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, flags FeatureFlags) (*mcp.CallToolResult, error) { +func GetIssue(ctx context.Context, client *github.Client, deps ToolDependencies, owner string, repo string, issueNumber int) (*mcp.CallToolResult, error) { + cache, err := deps.GetRepoAccessCache(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get repo access cache: %w", err) + } + flags := deps.GetFlags() + issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) if err != nil { return nil, fmt.Errorf("failed to get issue: %w", err) @@ -378,7 +384,13 @@ func GetIssue(ctx context.Context, client *github.Client, cache *lockdown.RepoAc return utils.NewToolResultText(string(r)), nil } -func GetIssueComments(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, pagination PaginationParams, flags FeatureFlags) (*mcp.CallToolResult, error) { +func GetIssueComments(ctx context.Context, client *github.Client, deps ToolDependencies, owner string, repo string, issueNumber int, pagination PaginationParams) (*mcp.CallToolResult, error) { + cache, err := deps.GetRepoAccessCache(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get repo access cache: %w", err) + } + flags := deps.GetFlags() + opts := &github.IssueListCommentsOptions{ ListOptions: github.ListOptions{ Page: pagination.Page, @@ -432,7 +444,13 @@ func GetIssueComments(ctx context.Context, client *github.Client, cache *lockdow return utils.NewToolResultText(string(r)), nil } -func GetSubIssues(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, pagination PaginationParams, featureFlags FeatureFlags) (*mcp.CallToolResult, error) { +func GetSubIssues(ctx context.Context, client *github.Client, deps ToolDependencies, owner string, repo string, issueNumber int, pagination PaginationParams) (*mcp.CallToolResult, error) { + cache, err := deps.GetRepoAccessCache(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get repo access cache: %w", err) + } + featureFlags := deps.GetFlags() + opts := &github.IssueListOptions{ ListOptions: github.ListOptions{ Page: pagination.Page, @@ -1789,7 +1807,7 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server // Add the GraphQL-Features header for the agent assignment API // The header will be read by the HTTP transport if it's configured to do so - ctxWithFeatures := withGraphQLFeatures(ctx, "issues_copilot_assignment_api_support") + ctxWithFeatures := ghcontext.WithGraphQLFeatures(ctx, "issues_copilot_assignment_api_support") if err := client.Mutate( ctxWithFeatures, @@ -1913,19 +1931,3 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) inventory.Ser }, ) } - -// graphQLFeaturesKey is a context key for GraphQL feature flags -type graphQLFeaturesKey struct{} - -// withGraphQLFeatures adds GraphQL feature flags to the context -func withGraphQLFeatures(ctx context.Context, features ...string) context.Context { - return context.WithValue(ctx, graphQLFeaturesKey{}, features) -} - -// GetGraphQLFeatures retrieves GraphQL feature flags from the context -func GetGraphQLFeatures(ctx context.Context) []string { - if features, ok := ctx.Value(graphQLFeaturesKey{}).([]string); ok { - return features - } - return nil -} diff --git a/pkg/github/params.go b/pkg/github/params.go new file mode 100644 index 000000000..42803a392 --- /dev/null +++ b/pkg/github/params.go @@ -0,0 +1,393 @@ +package github + +import ( + "errors" + "fmt" + "strconv" + + "github.com/google/go-github/v79/github" + "github.com/google/jsonschema-go/jsonschema" +) + +// OptionalParamOK is a helper function that can be used to fetch a requested parameter from the request. +// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong. +func OptionalParamOK[T any, A map[string]any](args A, p string) (value T, ok bool, err error) { + // Check if the parameter is present in the request + val, exists := args[p] + if !exists { + // Not present, return zero value, false, no error + return + } + + // Check if the parameter is of the expected type + value, ok = val.(T) + if !ok { + // Present but wrong type + err = fmt.Errorf("parameter %s is not of type %T, is %T", p, value, val) + ok = true // Set ok to true because the parameter *was* present, even if wrong type + return + } + + // Present and correct type + ok = true + return +} + +// isAcceptedError checks if the error is an accepted error. +func isAcceptedError(err error) bool { + var acceptedError *github.AcceptedError + return errors.As(err, &acceptedError) +} + +// RequiredParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func RequiredParam[T comparable](args map[string]any, p string) (T, error) { + var zero T + + // Check if the parameter is present in the request + if _, ok := args[p]; !ok { + return zero, fmt.Errorf("missing required parameter: %s", p) + } + + // Check if the parameter is of the expected type + val, ok := args[p].(T) + if !ok { + return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) + } + + if val == zero { + return zero, fmt.Errorf("missing required parameter: %s", p) + } + + return val, nil +} + +// RequiredInt is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func RequiredInt(args map[string]any, p string) (int, error) { + v, err := RequiredParam[float64](args, p) + if err != nil { + return 0, err + } + return int(v), nil +} + +// RequiredBigInt is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type (float64). +// 3. Checks if the parameter is not empty, i.e: non-zero value. +// 4. Validates that the float64 value can be safely converted to int64 without truncation. +func RequiredBigInt(args map[string]any, p string) (int64, error) { + v, err := RequiredParam[float64](args, p) + if err != nil { + return 0, err + } + + result := int64(v) + // Check if converting back produces the same value to avoid silent truncation + if float64(result) != v { + return 0, fmt.Errorf("parameter %s value %f is too large to fit in int64", p, v) + } + return result, nil +} + +// OptionalParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func OptionalParam[T any](args map[string]any, p string) (T, error) { + var zero T + + // Check if the parameter is present in the request + if _, ok := args[p]; !ok { + return zero, nil + } + + // Check if the parameter is of the expected type + if _, ok := args[p].(T); !ok { + return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, args[p]) + } + + return args[p].(T), nil +} + +// OptionalIntParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func OptionalIntParam(args map[string]any, p string) (int, error) { + v, err := OptionalParam[float64](args, p) + if err != nil { + return 0, err + } + return int(v), nil +} + +// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalIntParam, but it also takes a default value. +func OptionalIntParamWithDefault(args map[string]any, p string, d int) (int, error) { + v, err := OptionalIntParam(args, p) + if err != nil { + return 0, err + } + if v == 0 { + return d, nil + } + return v, nil +} + +// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalBoolParam, but it also takes a default value. +func OptionalBoolParamWithDefault(args map[string]any, p string, d bool) (bool, error) { + _, ok := args[p] + v, err := OptionalParam[bool](args, p) + if err != nil { + return false, err + } + if !ok { + return d, nil + } + return v, nil +} + +// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, iterates the elements and checks each is a string +func OptionalStringArrayParam(args map[string]any, p string) ([]string, error) { + // Check if the parameter is present in the request + if _, ok := args[p]; !ok { + return []string{}, nil + } + + switch v := args[p].(type) { + case nil: + return []string{}, nil + case []string: + return v, nil + case []any: + strSlice := make([]string, len(v)) + for i, v := range v { + s, ok := v.(string) + if !ok { + return []string{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) + } + strSlice[i] = s + } + return strSlice, nil + default: + return []string{}, fmt.Errorf("parameter %s could not be coerced to []string, is %T", p, args[p]) + } +} + +func convertStringSliceToBigIntSlice(s []string) ([]int64, error) { + int64Slice := make([]int64, len(s)) + for i, str := range s { + val, err := convertStringToBigInt(str, 0) + if err != nil { + return nil, fmt.Errorf("failed to convert element %d (%s) to int64: %w", i, str, err) + } + int64Slice[i] = val + } + return int64Slice, nil +} + +func convertStringToBigInt(s string, def int64) (int64, error) { + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return def, fmt.Errorf("failed to convert string %s to int64: %w", s, err) + } + return v, nil +} + +// OptionalBigIntArrayParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns an empty slice +// 2. If it is present, iterates the elements, checks each is a string, and converts them to int64 values +func OptionalBigIntArrayParam(args map[string]any, p string) ([]int64, error) { + // Check if the parameter is present in the request + if _, ok := args[p]; !ok { + return []int64{}, nil + } + + switch v := args[p].(type) { + case nil: + return []int64{}, nil + case []string: + return convertStringSliceToBigIntSlice(v) + case []any: + int64Slice := make([]int64, len(v)) + for i, v := range v { + s, ok := v.(string) + if !ok { + return []int64{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) + } + val, err := convertStringToBigInt(s, 0) + if err != nil { + return []int64{}, fmt.Errorf("parameter %s: failed to convert element %d (%s) to int64: %w", p, i, s, err) + } + int64Slice[i] = val + } + return int64Slice, nil + default: + return []int64{}, fmt.Errorf("parameter %s could not be coerced to []int64, is %T", p, args[p]) + } +} + +// WithPagination adds REST API pagination parameters to a tool. +// https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api +func WithPagination(schema *jsonschema.Schema) *jsonschema.Schema { + schema.Properties["page"] = &jsonschema.Schema{ + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + } + + schema.Properties["perPage"] = &jsonschema.Schema{ + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + } + + return schema +} + +// WithUnifiedPagination adds REST API pagination parameters to a tool. +// GraphQL tools will use this and convert page/perPage to GraphQL cursor parameters internally. +func WithUnifiedPagination(schema *jsonschema.Schema) *jsonschema.Schema { + schema.Properties["page"] = &jsonschema.Schema{ + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + } + + schema.Properties["perPage"] = &jsonschema.Schema{ + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + } + + schema.Properties["after"] = &jsonschema.Schema{ + Type: "string", + Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + } + + return schema +} + +// WithCursorPagination adds only cursor-based pagination parameters to a tool (no page parameter). +func WithCursorPagination(schema *jsonschema.Schema) *jsonschema.Schema { + schema.Properties["perPage"] = &jsonschema.Schema{ + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + } + + schema.Properties["after"] = &jsonschema.Schema{ + Type: "string", + Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + } + + return schema +} + +type PaginationParams struct { + Page int + PerPage int + After string +} + +// OptionalPaginationParams returns the "page", "perPage", and "after" parameters from the request, +// or their default values if not present, "page" default is 1, "perPage" default is 30. +// In future, we may want to make the default values configurable, or even have this +// function returned from `withPagination`, where the defaults are provided alongside +// the min/max values. +func OptionalPaginationParams(args map[string]any) (PaginationParams, error) { + page, err := OptionalIntParamWithDefault(args, "page", 1) + if err != nil { + return PaginationParams{}, err + } + perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) + if err != nil { + return PaginationParams{}, err + } + after, err := OptionalParam[string](args, "after") + if err != nil { + return PaginationParams{}, err + } + return PaginationParams{ + Page: page, + PerPage: perPage, + After: after, + }, nil +} + +// OptionalCursorPaginationParams returns the "perPage" and "after" parameters from the request, +// without the "page" parameter, suitable for cursor-based pagination only. +func OptionalCursorPaginationParams(args map[string]any) (CursorPaginationParams, error) { + perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) + if err != nil { + return CursorPaginationParams{}, err + } + after, err := OptionalParam[string](args, "after") + if err != nil { + return CursorPaginationParams{}, err + } + return CursorPaginationParams{ + PerPage: perPage, + After: after, + }, nil +} + +type CursorPaginationParams struct { + PerPage int + After string +} + +// ToGraphQLParams converts cursor pagination parameters to GraphQL-specific parameters. +func (p CursorPaginationParams) ToGraphQLParams() (*GraphQLPaginationParams, error) { + if p.PerPage > 100 { + return nil, fmt.Errorf("perPage value %d exceeds maximum of 100", p.PerPage) + } + if p.PerPage < 0 { + return nil, fmt.Errorf("perPage value %d cannot be negative", p.PerPage) + } + first := int32(p.PerPage) + + var after *string + if p.After != "" { + after = &p.After + } + + return &GraphQLPaginationParams{ + First: &first, + After: after, + }, nil +} + +type GraphQLPaginationParams struct { + First *int32 + After *string +} + +// ToGraphQLParams converts REST API pagination parameters to GraphQL-specific parameters. +// This converts page/perPage to first parameter for GraphQL queries. +// If After is provided, it takes precedence over page-based pagination. +func (p PaginationParams) ToGraphQLParams() (*GraphQLPaginationParams, error) { + // Convert to CursorPaginationParams and delegate to avoid duplication + cursor := CursorPaginationParams{ + PerPage: p.PerPage, + After: p.After, + } + return cursor.ToGraphQLParams() +} diff --git a/pkg/github/params_test.go b/pkg/github/params_test.go new file mode 100644 index 000000000..9d7cfe432 --- /dev/null +++ b/pkg/github/params_test.go @@ -0,0 +1,503 @@ +package github + +import ( + "fmt" + "testing" + + "github.com/google/go-github/v79/github" + "github.com/stretchr/testify/assert" +) + +func Test_IsAcceptedError(t *testing.T) { + tests := []struct { + name string + err error + expectAccepted bool + }{ + { + name: "github AcceptedError", + err: &github.AcceptedError{}, + expectAccepted: true, + }, + { + name: "regular error", + err: fmt.Errorf("some other error"), + expectAccepted: false, + }, + { + name: "nil error", + err: nil, + expectAccepted: false, + }, + { + name: "wrapped AcceptedError", + err: fmt.Errorf("wrapped: %w", &github.AcceptedError{}), + expectAccepted: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := isAcceptedError(tc.err) + assert.Equal(t, tc.expectAccepted, result) + }) + } +} + +func Test_RequiredStringParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected string + expectError bool + }{ + { + name: "valid string parameter", + params: map[string]interface{}{"name": "test-value"}, + paramName: "name", + expected: "test-value", + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "name", + expected: "", + expectError: true, + }, + { + name: "empty string parameter", + params: map[string]interface{}{"name": ""}, + paramName: "name", + expected: "", + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"name": 123}, + paramName: "name", + expected: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := RequiredParam[string](tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalStringParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected string + expectError bool + }{ + { + name: "valid string parameter", + params: map[string]interface{}{"name": "test-value"}, + paramName: "name", + expected: "test-value", + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "name", + expected: "", + expectError: false, + }, + { + name: "empty string parameter", + params: map[string]interface{}{"name": ""}, + paramName: "name", + expected: "", + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"name": 123}, + paramName: "name", + expected: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalParam[string](tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_RequiredInt(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := RequiredInt(tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} +func Test_OptionalIntParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "zero value", + params: map[string]interface{}{"count": float64(0)}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalIntParam(tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalNumberParamWithDefault(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + defaultVal int + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + defaultVal: 10, + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + defaultVal: 10, + expected: 10, + expectError: false, + }, + { + name: "zero value", + params: map[string]interface{}{"count": float64(0)}, + paramName: "count", + defaultVal: 10, + expected: 10, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + defaultVal: 10, + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalIntParamWithDefault(tc.params, tc.paramName, tc.defaultVal) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalBooleanParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected bool + expectError bool + }{ + { + name: "true value", + params: map[string]interface{}{"flag": true}, + paramName: "flag", + expected: true, + expectError: false, + }, + { + name: "false value", + params: map[string]interface{}{"flag": false}, + paramName: "flag", + expected: false, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "flag", + expected: false, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"flag": "not-a-boolean"}, + paramName: "flag", + expected: false, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalParam[bool](tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func TestOptionalStringArrayParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected []string + expectError bool + }{ + { + name: "parameter not in request", + params: map[string]any{}, + paramName: "flag", + expected: []string{}, + expectError: false, + }, + { + name: "valid any array parameter", + params: map[string]any{ + "flag": []any{"v1", "v2"}, + }, + paramName: "flag", + expected: []string{"v1", "v2"}, + expectError: false, + }, + { + name: "valid string array parameter", + params: map[string]any{ + "flag": []string{"v1", "v2"}, + }, + paramName: "flag", + expected: []string{"v1", "v2"}, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]any{ + "flag": 1, + }, + paramName: "flag", + expected: []string{}, + expectError: true, + }, + { + name: "wrong slice type parameter", + params: map[string]any{ + "flag": []any{"foo", 2}, + }, + paramName: "flag", + expected: []string{}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalStringArrayParam(tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func TestOptionalPaginationParams(t *testing.T) { + tests := []struct { + name string + params map[string]any + expected PaginationParams + expectError bool + }{ + { + name: "no pagination parameters, default values", + params: map[string]any{}, + expected: PaginationParams{ + Page: 1, + PerPage: 30, + }, + expectError: false, + }, + { + name: "page parameter, default perPage", + params: map[string]any{ + "page": float64(2), + }, + expected: PaginationParams{ + Page: 2, + PerPage: 30, + }, + expectError: false, + }, + { + name: "perPage parameter, default page", + params: map[string]any{ + "perPage": float64(50), + }, + expected: PaginationParams{ + Page: 1, + PerPage: 50, + }, + expectError: false, + }, + { + name: "page and perPage parameters", + params: map[string]any{ + "page": float64(2), + "perPage": float64(50), + }, + expected: PaginationParams{ + Page: 2, + PerPage: 50, + }, + expectError: false, + }, + { + name: "invalid page parameter", + params: map[string]any{ + "page": "not-a-number", + }, + expected: PaginationParams{}, + expectError: true, + }, + { + name: "invalid perPage parameter", + params: map[string]any{ + "perPage": "not-a-number", + }, + expected: PaginationParams{}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalPaginationParams(tc.params) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 62952783e..614e4496f 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -15,7 +15,6 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/inventory" - "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/octicons" "github.com/github/github-mcp-server/pkg/sanitize" "github.com/github/github-mcp-server/pkg/scopes" @@ -101,7 +100,7 @@ Possible options: switch method { case "get": - result, err := GetPullRequest(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) + result, err := GetPullRequest(ctx, client, deps, owner, repo, pullNumber) return result, nil, err case "get_diff": result, err := GetPullRequestDiff(ctx, client, owner, repo, pullNumber) @@ -121,13 +120,13 @@ Possible options: if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } - result, err := GetPullRequestReviewComments(ctx, gqlClient, deps.GetRepoAccessCache(), owner, repo, pullNumber, cursorPagination, deps.GetFlags()) + result, err := GetPullRequestReviewComments(ctx, gqlClient, deps, owner, repo, pullNumber, cursorPagination) return result, nil, err case "get_reviews": - result, err := GetPullRequestReviews(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) + result, err := GetPullRequestReviews(ctx, client, deps, owner, repo, pullNumber) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, pagination, deps.GetFlags()) + result, err := GetIssueComments(ctx, client, deps, owner, repo, pullNumber, pagination) return result, nil, err default: return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil @@ -135,7 +134,13 @@ Possible options: }) } -func GetPullRequest(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, ff FeatureFlags) (*mcp.CallToolResult, error) { +func GetPullRequest(ctx context.Context, client *github.Client, deps ToolDependencies, owner, repo string, pullNumber int) (*mcp.CallToolResult, error) { + cache, err := deps.GetRepoAccessCache(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get repo access cache: %w", err) + } + ff := deps.GetFlags() + pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, @@ -340,7 +345,13 @@ type pageInfoFragment struct { EndCursor githubv4.String } -func GetPullRequestReviewComments(ctx context.Context, gqlClient *githubv4.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, pagination CursorPaginationParams, ff FeatureFlags) (*mcp.CallToolResult, error) { +func GetPullRequestReviewComments(ctx context.Context, gqlClient *githubv4.Client, deps ToolDependencies, owner, repo string, pullNumber int, pagination CursorPaginationParams) (*mcp.CallToolResult, error) { + cache, err := deps.GetRepoAccessCache(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get repo access cache: %w", err) + } + ff := deps.GetFlags() + // Convert pagination parameters to GraphQL format gqlParams, err := pagination.ToGraphQLParams() if err != nil { @@ -421,7 +432,13 @@ func GetPullRequestReviewComments(ctx context.Context, gqlClient *githubv4.Clien return utils.NewToolResultText(string(r)), nil } -func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, ff FeatureFlags) (*mcp.CallToolResult, error) { +func GetPullRequestReviews(ctx context.Context, client *github.Client, deps ToolDependencies, owner, repo string, pullNumber int) (*mcp.CallToolResult, error) { + cache, err := deps.GetRepoAccessCache(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get repo access cache: %w", err) + } + ff := deps.GetFlags() + reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, diff --git a/pkg/github/server.go b/pkg/github/server.go index 8248da58f..9ee0a59f4 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -3,433 +3,216 @@ package github import ( "context" "encoding/json" - "errors" "fmt" - "strconv" + "log/slog" + "os" "strings" + "time" + gherrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/octicons" + "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" - "github.com/google/go-github/v79/github" - "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" ) -// NewServer creates a new GitHub MCP server with the specified GH client and logger. +type MCPServerConfig struct { + // Version of the server + Version string -func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { - if opts == nil { - opts = &mcp.ServerOptions{} - } + // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) + Host string - // Create a new MCP server - s := mcp.NewServer(&mcp.Implementation{ - Name: "github-mcp-server", - Title: "GitHub MCP Server", - Version: version, - Icons: octicons.Icons("mark-github"), - }, opts) + // GitHub Token to authenticate with the GitHub API + Token string - return s -} + // EnabledToolsets is a list of toolsets to enable + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#tool-configuration + EnabledToolsets []string -func CompletionsHandler(getClient GetClientFn) func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { - return func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { - switch req.Params.Ref.Type { - case "ref/resource": - if strings.HasPrefix(req.Params.Ref.URI, "repo://") { - return RepositoryResourceCompletionHandler(getClient)(ctx, req) - } - return nil, fmt.Errorf("unsupported resource URI: %s", req.Params.Ref.URI) - case "ref/prompt": - return nil, nil - default: - return nil, fmt.Errorf("unsupported ref type: %s", req.Params.Ref.Type) - } - } -} + // EnabledTools is a list of specific tools to enable (additive to toolsets) + // When specified, these tools are registered in addition to any specified toolset tools + EnabledTools []string -// OptionalParamOK is a helper function that can be used to fetch a requested parameter from the request. -// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong. -func OptionalParamOK[T any, A map[string]any](args A, p string) (value T, ok bool, err error) { - // Check if the parameter is present in the request - val, exists := args[p] - if !exists { - // Not present, return zero value, false, no error - return - } + // EnabledFeatures is a list of feature flags that are enabled + // Items with FeatureFlagEnable matching an entry in this list will be available + EnabledFeatures []string - // Check if the parameter is of the expected type - value, ok = val.(T) - if !ok { - // Present but wrong type - err = fmt.Errorf("parameter %s is not of type %T, is %T", p, value, val) - ok = true // Set ok to true because the parameter *was* present, even if wrong type - return - } + // Whether to enable dynamic toolsets + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery + DynamicToolsets bool - // Present and correct type - ok = true - return -} - -// isAcceptedError checks if the error is an accepted error. -func isAcceptedError(err error) bool { - var acceptedError *github.AcceptedError - return errors.As(err, &acceptedError) -} + // ReadOnly indicates if we should only offer read-only tools + ReadOnly bool -// RequiredParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request. -// 2. Checks if the parameter is of the expected type. -// 3. Checks if the parameter is not empty, i.e: non-zero value -func RequiredParam[T comparable](args map[string]any, p string) (T, error) { - var zero T - - // Check if the parameter is present in the request - if _, ok := args[p]; !ok { - return zero, fmt.Errorf("missing required parameter: %s", p) - } + // Translator provides translated text for the server tooling + Translator translations.TranslationHelperFunc - // Check if the parameter is of the expected type - val, ok := args[p].(T) - if !ok { - return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) - } + // Content window size + ContentWindowSize int - if val == zero { - return zero, fmt.Errorf("missing required parameter: %s", p) - } + // LockdownMode indicates if we should enable lockdown mode + LockdownMode bool - return val, nil -} + // Logger is used for logging within the server + Logger *slog.Logger + // RepoAccessTTL overrides the default TTL for repository access cache entries. + RepoAccessTTL *time.Duration -// RequiredInt is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request. -// 2. Checks if the parameter is of the expected type. -// 3. Checks if the parameter is not empty, i.e: non-zero value -func RequiredInt(args map[string]any, p string) (int, error) { - v, err := RequiredParam[float64](args, p) - if err != nil { - return 0, err - } - return int(v), nil -} - -// RequiredBigInt is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request. -// 2. Checks if the parameter is of the expected type (float64). -// 3. Checks if the parameter is not empty, i.e: non-zero value. -// 4. Validates that the float64 value can be safely converted to int64 without truncation. -func RequiredBigInt(args map[string]any, p string) (int64, error) { - v, err := RequiredParam[float64](args, p) - if err != nil { - return 0, err - } - - result := int64(v) - // Check if converting back produces the same value to avoid silent truncation - if float64(result) != v { - return 0, fmt.Errorf("parameter %s value %f is too large to fit in int64", p, v) - } - return result, nil + // TokenScopes contains the OAuth scopes available to the token. + // When non-nil, tools requiring scopes not in this list will be hidden. + // This is used for PAT scope filtering where we can't issue scope challenges. + TokenScopes []string } -// OptionalParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns its zero-value -// 2. If it is present, it checks if the parameter is of the expected type and returns it -func OptionalParam[T any](args map[string]any, p string) (T, error) { - var zero T +func NewMCPServer(cfg *MCPServerConfig, deps ToolDependencies, inventory *inventory.Inventory) (*mcp.Server, error) { + enabledToolsets := resolveEnabledToolsets(cfg) - // Check if the parameter is present in the request - if _, ok := args[p]; !ok { - return zero, nil + // For instruction generation, we need actual toolset names (not nil). + // nil means "use defaults" in inventory, so expand it for instructions. + instructionToolsets := enabledToolsets + if instructionToolsets == nil { + instructionToolsets = GetDefaultToolsetIDs() } - // Check if the parameter is of the expected type - if _, ok := args[p].(T); !ok { - return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, args[p]) + // Create the MCP server + serverOpts := &mcp.ServerOptions{ + Instructions: GenerateInstructions(instructionToolsets), + Logger: cfg.Logger, + CompletionHandler: CompletionsHandler(deps.GetClient), } - return args[p].(T), nil -} - -// OptionalIntParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns its zero-value -// 2. If it is present, it checks if the parameter is of the expected type and returns it -func OptionalIntParam(args map[string]any, p string) (int, error) { - v, err := OptionalParam[float64](args, p) - if err != nil { - return 0, err + // In dynamic mode, explicitly advertise capabilities since tools/resources/prompts + // may be enabled at runtime even if none are registered initially. + if cfg.DynamicToolsets { + serverOpts.Capabilities = &mcp.ServerCapabilities{ + Tools: &mcp.ToolCapabilities{}, + Resources: &mcp.ResourceCapabilities{}, + Prompts: &mcp.PromptCapabilities{}, + } } - return int(v), nil -} -// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request -// similar to optionalIntParam, but it also takes a default value. -func OptionalIntParamWithDefault(args map[string]any, p string, d int) (int, error) { - v, err := OptionalIntParam(args, p) - if err != nil { - return 0, err - } - if v == 0 { - return d, nil - } - return v, nil -} + ghServer := NewServer(cfg.Version, serverOpts) -// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request -// similar to optionalBoolParam, but it also takes a default value. -func OptionalBoolParamWithDefault(args map[string]any, p string, d bool) (bool, error) { - _, ok := args[p] - v, err := OptionalParam[bool](args, p) - if err != nil { - return false, err - } - if !ok { - return d, nil - } - return v, nil -} + // Add middlewares + ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) + ghServer.AddReceivingMiddleware(InjectDepsMiddleware(deps)) -// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns its zero-value -// 2. If it is present, iterates the elements and checks each is a string -func OptionalStringArrayParam(args map[string]any, p string) ([]string, error) { - // Check if the parameter is present in the request - if _, ok := args[p]; !ok { - return []string{}, nil + if unrecognized := inventory.UnrecognizedToolsets(); len(unrecognized) > 0 { + fmt.Fprintf(os.Stderr, "Warning: unrecognized toolsets ignored: %s\n", strings.Join(unrecognized, ", ")) } - switch v := args[p].(type) { - case nil: - return []string{}, nil - case []string: - return v, nil - case []any: - strSlice := make([]string, len(v)) - for i, v := range v { - s, ok := v.(string) - if !ok { - return []string{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) - } - strSlice[i] = s - } - return strSlice, nil - default: - return []string{}, fmt.Errorf("parameter %s could not be coerced to []string, is %T", p, args[p]) - } -} + // Register GitHub tools/resources/prompts from the inventory. + // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets + // is empty - users enable toolsets at runtime via the dynamic tools below (but can + // enable toolsets or tools explicitly that do need registration). + inventory.RegisterAll(context.Background(), ghServer, deps) -func convertStringSliceToBigIntSlice(s []string) ([]int64, error) { - int64Slice := make([]int64, len(s)) - for i, str := range s { - val, err := convertStringToBigInt(str, 0) - if err != nil { - return nil, fmt.Errorf("failed to convert element %d (%s) to int64: %w", i, str, err) - } - int64Slice[i] = val + // Register dynamic toolset management tools (enable/disable) - these are separate + // meta-tools that control the inventory, not part of the inventory itself + if cfg.DynamicToolsets { + registerDynamicTools(ghServer, inventory, deps, cfg.Translator) } - return int64Slice, nil -} -func convertStringToBigInt(s string, def int64) (int64, error) { - v, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return def, fmt.Errorf("failed to convert string %s to int64: %w", s, err) - } - return v, nil + return ghServer, nil } -// OptionalBigIntArrayParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns an empty slice -// 2. If it is present, iterates the elements, checks each is a string, and converts them to int64 values -func OptionalBigIntArrayParam(args map[string]any, p string) ([]int64, error) { - // Check if the parameter is present in the request - if _, ok := args[p]; !ok { - return []int64{}, nil +// registerDynamicTools adds the dynamic toolset enable/disable tools to the server. +func registerDynamicTools(server *mcp.Server, inventory *inventory.Inventory, deps ToolDependencies, t translations.TranslationHelperFunc) { + dynamicDeps := DynamicToolDependencies{ + Server: server, + Inventory: inventory, + ToolDeps: deps, + T: t, } - - switch v := args[p].(type) { - case nil: - return []int64{}, nil - case []string: - return convertStringSliceToBigIntSlice(v) - case []any: - int64Slice := make([]int64, len(v)) - for i, v := range v { - s, ok := v.(string) - if !ok { - return []int64{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) - } - val, err := convertStringToBigInt(s, 0) - if err != nil { - return []int64{}, fmt.Errorf("parameter %s: failed to convert element %d (%s) to int64: %w", p, i, s, err) - } - int64Slice[i] = val - } - return int64Slice, nil - default: - return []int64{}, fmt.Errorf("parameter %s could not be coerced to []int64, is %T", p, args[p]) + for _, tool := range DynamicTools(inventory) { + tool.RegisterFunc(server, dynamicDeps) } } -// WithPagination adds REST API pagination parameters to a tool. -// https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api -func WithPagination(schema *jsonschema.Schema) *jsonschema.Schema { - schema.Properties["page"] = &jsonschema.Schema{ - Type: "number", - Description: "Page number for pagination (min 1)", - Minimum: jsonschema.Ptr(1.0), +// createFeatureChecker returns a FeatureFlagChecker that checks if a flag name +// is present in the provided list of enabled features. For the local server, +// this is populated from the --features CLI flag. +func createFeatureChecker(enabledFeatures []string) inventory.FeatureFlagChecker { + // Build a set for O(1) lookup + featureSet := make(map[string]bool, len(enabledFeatures)) + for _, f := range enabledFeatures { + featureSet[f] = true } - - schema.Properties["perPage"] = &jsonschema.Schema{ - Type: "number", - Description: "Results per page for pagination (min 1, max 100)", - Minimum: jsonschema.Ptr(1.0), - Maximum: jsonschema.Ptr(100.0), + return func(_ context.Context, flagName string) (bool, error) { + return featureSet[flagName], nil } - - return schema } -// WithUnifiedPagination adds REST API pagination parameters to a tool. -// GraphQL tools will use this and convert page/perPage to GraphQL cursor parameters internally. -func WithUnifiedPagination(schema *jsonschema.Schema) *jsonschema.Schema { - schema.Properties["page"] = &jsonschema.Schema{ - Type: "number", - Description: "Page number for pagination (min 1)", - Minimum: jsonschema.Ptr(1.0), - } +// resolveEnabledToolsets determines which toolsets should be enabled based on config. +// Returns nil for "use defaults", empty slice for "none", or explicit list. +func resolveEnabledToolsets(cfg *MCPServerConfig) []string { + enabledToolsets := cfg.EnabledToolsets - schema.Properties["perPage"] = &jsonschema.Schema{ - Type: "number", - Description: "Results per page for pagination (min 1, max 100)", - Minimum: jsonschema.Ptr(1.0), - Maximum: jsonschema.Ptr(100.0), + // In dynamic mode, remove "all" and "default" since users enable toolsets on demand + if cfg.DynamicToolsets && enabledToolsets != nil { + enabledToolsets = RemoveToolset(enabledToolsets, string(ToolsetMetadataAll.ID)) + enabledToolsets = RemoveToolset(enabledToolsets, string(ToolsetMetadataDefault.ID)) } - schema.Properties["after"] = &jsonschema.Schema{ - Type: "string", - Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + if enabledToolsets != nil { + return enabledToolsets } - - return schema -} - -// WithCursorPagination adds only cursor-based pagination parameters to a tool (no page parameter). -func WithCursorPagination(schema *jsonschema.Schema) *jsonschema.Schema { - schema.Properties["perPage"] = &jsonschema.Schema{ - Type: "number", - Description: "Results per page for pagination (min 1, max 100)", - Minimum: jsonschema.Ptr(1.0), - Maximum: jsonschema.Ptr(100.0), + if cfg.DynamicToolsets { + // Dynamic mode with no toolsets specified: start empty so users enable on demand + return []string{} } - - schema.Properties["after"] = &jsonschema.Schema{ - Type: "string", - Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + if len(cfg.EnabledTools) > 0 { + // When specific tools are requested but no toolsets, don't use default toolsets + // This matches the original behavior: --tools=X alone registers only X + return []string{} } - return schema -} - -type PaginationParams struct { - Page int - PerPage int - After string + // nil means "use defaults" in WithToolsets + return nil } -// OptionalPaginationParams returns the "page", "perPage", and "after" parameters from the request, -// or their default values if not present, "page" default is 1, "perPage" default is 30. -// In future, we may want to make the default values configurable, or even have this -// function returned from `withPagination`, where the defaults are provided alongside -// the min/max values. -func OptionalPaginationParams(args map[string]any) (PaginationParams, error) { - page, err := OptionalIntParamWithDefault(args, "page", 1) - if err != nil { - return PaginationParams{}, err - } - perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) - if err != nil { - return PaginationParams{}, err - } - after, err := OptionalParam[string](args, "after") - if err != nil { - return PaginationParams{}, err +func addGitHubAPIErrorToContext(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) { + // Ensure the context is cleared of any previous errors + // as context isn't propagated through middleware + ctx = gherrors.ContextWithGitHubErrors(ctx) + return next(ctx, method, req) } - return PaginationParams{ - Page: page, - PerPage: perPage, - After: after, - }, nil } -// OptionalCursorPaginationParams returns the "perPage" and "after" parameters from the request, -// without the "page" parameter, suitable for cursor-based pagination only. -func OptionalCursorPaginationParams(args map[string]any) (CursorPaginationParams, error) { - perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) - if err != nil { - return CursorPaginationParams{}, err - } - after, err := OptionalParam[string](args, "after") - if err != nil { - return CursorPaginationParams{}, err - } - return CursorPaginationParams{ - PerPage: perPage, - After: after, - }, nil -} - -type CursorPaginationParams struct { - PerPage int - After string -} - -// ToGraphQLParams converts cursor pagination parameters to GraphQL-specific parameters. -func (p CursorPaginationParams) ToGraphQLParams() (*GraphQLPaginationParams, error) { - if p.PerPage > 100 { - return nil, fmt.Errorf("perPage value %d exceeds maximum of 100", p.PerPage) - } - if p.PerPage < 0 { - return nil, fmt.Errorf("perPage value %d cannot be negative", p.PerPage) - } - first := int32(p.PerPage) - - var after *string - if p.After != "" { - after = &p.After +// NewServer creates a new GitHub MCP server with the specified GH client and logger. +func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { + if opts == nil { + opts = &mcp.ServerOptions{} } - return &GraphQLPaginationParams{ - First: &first, - After: after, - }, nil -} + // Create a new MCP server + s := mcp.NewServer(&mcp.Implementation{ + Name: "github-mcp-server", + Title: "GitHub MCP Server", + Version: version, + Icons: octicons.Icons("mark-github"), + }, opts) -type GraphQLPaginationParams struct { - First *int32 - After *string + return s } -// ToGraphQLParams converts REST API pagination parameters to GraphQL-specific parameters. -// This converts page/perPage to first parameter for GraphQL queries. -// If After is provided, it takes precedence over page-based pagination. -func (p PaginationParams) ToGraphQLParams() (*GraphQLPaginationParams, error) { - // Convert to CursorPaginationParams and delegate to avoid duplication - cursor := CursorPaginationParams{ - PerPage: p.PerPage, - After: p.After, +func CompletionsHandler(getClient GetClientFn) func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { + return func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { + switch req.Params.Ref.Type { + case "ref/resource": + if strings.HasPrefix(req.Params.Ref.URI, "repo://") { + return RepositoryResourceCompletionHandler(getClient)(ctx, req) + } + return nil, fmt.Errorf("unsupported resource URI: %s", req.Params.Ref.URI) + case "ref/prompt": + return nil, nil + default: + return nil, fmt.Errorf("unsupported ref type: %s", req.Params.Ref.Type) + } } - return cursor.ToGraphQLParams() } func MarshalledTextResult(v any) *mcp.CallToolResult { diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index a59cd9a93..42377a802 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -15,6 +15,7 @@ import ( "github.com/google/go-github/v79/github" "github.com/shurcooL/githubv4" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // stubDeps is a test helper that implements ToolDependencies with configurable behavior. @@ -51,10 +52,12 @@ func (s stubDeps) GetRawClient(ctx context.Context) (*raw.Client, error) { return nil, nil } -func (s stubDeps) GetRepoAccessCache() *lockdown.RepoAccessCache { return s.repoAccessCache } -func (s stubDeps) GetT() translations.TranslationHelperFunc { return s.t } -func (s stubDeps) GetFlags() FeatureFlags { return s.flags } -func (s stubDeps) GetContentWindowSize() int { return s.contentWindowSize } +func (s stubDeps) GetRepoAccessCache(ctx context.Context) (*lockdown.RepoAccessCache, error) { + return s.repoAccessCache, nil +} +func (s stubDeps) GetT() translations.TranslationHelperFunc { return s.t } +func (s stubDeps) GetFlags() FeatureFlags { return s.flags } +func (s stubDeps) GetContentWindowSize() int { return s.contentWindowSize } // Helper functions to create stub client functions for error testing func stubClientFnFromHTTP(httpClient *http.Client) func(context.Context) (*github.Client, error) { @@ -101,496 +104,107 @@ func badRequestHandler(msg string) http.HandlerFunc { } } -func Test_IsAcceptedError(t *testing.T) { - tests := []struct { - name string - err error - expectAccepted bool - }{ - { - name: "github AcceptedError", - err: &github.AcceptedError{}, - expectAccepted: true, - }, - { - name: "regular error", - err: fmt.Errorf("some other error"), - expectAccepted: false, - }, - { - name: "nil error", - err: nil, - expectAccepted: false, - }, - { - name: "wrapped AcceptedError", - err: fmt.Errorf("wrapped: %w", &github.AcceptedError{}), - expectAccepted: true, - }, - } +// TestNewMCPServer_CreatesSuccessfully verifies that the server can be created +// with the deps injection middleware properly configured. +func TestNewMCPServer_CreatesSuccessfully(t *testing.T) { + t.Parallel() - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := isAcceptedError(tc.err) - assert.Equal(t, tc.expectAccepted, result) - }) + // Create a minimal server configuration + cfg := MCPServerConfig{ + Version: "test", + Host: "", // defaults to github.com + Token: "test-token", + EnabledToolsets: []string{"context"}, + ReadOnly: false, + Translator: translations.NullTranslationHelper, + ContentWindowSize: 5000, + LockdownMode: false, } -} -func Test_RequiredStringParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected string - expectError bool - }{ - { - name: "valid string parameter", - params: map[string]interface{}{"name": "test-value"}, - paramName: "name", - expected: "test-value", - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "name", - expected: "", - expectError: true, - }, - { - name: "empty string parameter", - params: map[string]interface{}{"name": ""}, - paramName: "name", - expected: "", - expectError: true, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"name": 123}, - paramName: "name", - expected: "", - expectError: true, - }, - } + deps := stubDeps{} - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := RequiredParam[string](tc.params, tc.paramName) + // Create the server + server, err := NewMCPServer(&cfg, deps) + require.NoError(t, err, "expected server creation to succeed") + require.NotNil(t, server, "expected server to be non-nil") - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } + // The fact that the server was created successfully indicates that: + // 1. The deps injection middleware is properly added + // 2. Tools can be registered without panicking + // + // If the middleware wasn't properly added, tool calls would panic with + // "ToolDependencies not found in context" when executed. + // + // The actual middleware functionality and tool execution with ContextWithDeps + // is already tested in pkg/github/*_test.go. } -func Test_OptionalStringParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected string - expectError bool - }{ - { - name: "valid string parameter", - params: map[string]interface{}{"name": "test-value"}, - paramName: "name", - expected: "test-value", - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "name", - expected: "", - expectError: false, - }, - { - name: "empty string parameter", - params: map[string]interface{}{"name": ""}, - paramName: "name", - expected: "", - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"name": 123}, - paramName: "name", - expected: "", - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalParam[string](tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func Test_RequiredInt(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected int - expectError bool - }{ - { - name: "valid number parameter", - params: map[string]interface{}{"count": float64(42)}, - paramName: "count", - expected: 42, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "count", - expected: 0, - expectError: true, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"count": "not-a-number"}, - paramName: "count", - expected: 0, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := RequiredInt(tc.params, tc.paramName) +// TestResolveEnabledToolsets verifies the toolset resolution logic. +func TestResolveEnabledToolsets(t *testing.T) { + t.Parallel() - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} -func Test_OptionalIntParam(t *testing.T) { tests := []struct { - name string - params map[string]interface{} - paramName string - expected int - expectError bool - }{ - { - name: "valid number parameter", - params: map[string]interface{}{"count": float64(42)}, - paramName: "count", - expected: 42, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "count", - expected: 0, - expectError: false, - }, - { - name: "zero value", - params: map[string]interface{}{"count": float64(0)}, - paramName: "count", - expected: 0, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"count": "not-a-number"}, - paramName: "count", - expected: 0, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalIntParam(tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func Test_OptionalNumberParamWithDefault(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - defaultVal int - expected int - expectError bool - }{ - { - name: "valid number parameter", - params: map[string]interface{}{"count": float64(42)}, - paramName: "count", - defaultVal: 10, - expected: 42, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "count", - defaultVal: 10, - expected: 10, - expectError: false, - }, - { - name: "zero value", - params: map[string]interface{}{"count": float64(0)}, - paramName: "count", - defaultVal: 10, - expected: 10, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"count": "not-a-number"}, - paramName: "count", - defaultVal: 10, - expected: 0, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalIntParamWithDefault(tc.params, tc.paramName, tc.defaultVal) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func Test_OptionalBooleanParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected bool - expectError bool - }{ - { - name: "true value", - params: map[string]interface{}{"flag": true}, - paramName: "flag", - expected: true, - expectError: false, - }, - { - name: "false value", - params: map[string]interface{}{"flag": false}, - paramName: "flag", - expected: false, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "flag", - expected: false, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"flag": "not-a-boolean"}, - paramName: "flag", - expected: false, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalParam[bool](tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func TestOptionalStringArrayParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected []string - expectError bool - }{ - { - name: "parameter not in request", - params: map[string]any{}, - paramName: "flag", - expected: []string{}, - expectError: false, - }, - { - name: "valid any array parameter", - params: map[string]any{ - "flag": []any{"v1", "v2"}, - }, - paramName: "flag", - expected: []string{"v1", "v2"}, - expectError: false, - }, - { - name: "valid string array parameter", - params: map[string]any{ - "flag": []string{"v1", "v2"}, - }, - paramName: "flag", - expected: []string{"v1", "v2"}, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]any{ - "flag": 1, - }, - paramName: "flag", - expected: []string{}, - expectError: true, - }, - { - name: "wrong slice type parameter", - params: map[string]any{ - "flag": []any{"foo", 2}, - }, - paramName: "flag", - expected: []string{}, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalStringArrayParam(tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func TestOptionalPaginationParams(t *testing.T) { - tests := []struct { - name string - params map[string]any - expected PaginationParams - expectError bool + name string + cfg MCPServerConfig + expectedResult []string }{ { - name: "no pagination parameters, default values", - params: map[string]any{}, - expected: PaginationParams{ - Page: 1, - PerPage: 30, + name: "nil toolsets without dynamic mode and no tools - use defaults", + cfg: MCPServerConfig{ + EnabledToolsets: nil, + DynamicToolsets: false, + EnabledTools: nil, }, - expectError: false, + expectedResult: nil, // nil means "use defaults" }, { - name: "page parameter, default perPage", - params: map[string]any{ - "page": float64(2), + name: "nil toolsets with dynamic mode - start empty", + cfg: MCPServerConfig{ + EnabledToolsets: nil, + DynamicToolsets: true, + EnabledTools: nil, }, - expected: PaginationParams{ - Page: 2, - PerPage: 30, - }, - expectError: false, + expectedResult: []string{}, // empty slice means no toolsets }, { - name: "perPage parameter, default page", - params: map[string]any{ - "perPage": float64(50), - }, - expected: PaginationParams{ - Page: 1, - PerPage: 50, + name: "explicit toolsets", + cfg: MCPServerConfig{ + EnabledToolsets: []string{"repos", "issues"}, + DynamicToolsets: false, }, - expectError: false, + expectedResult: []string{"repos", "issues"}, }, { - name: "page and perPage parameters", - params: map[string]any{ - "page": float64(2), - "perPage": float64(50), + name: "empty toolsets - disable all", + cfg: MCPServerConfig{ + EnabledToolsets: []string{}, + DynamicToolsets: false, }, - expected: PaginationParams{ - Page: 2, - PerPage: 50, - }, - expectError: false, + expectedResult: []string{}, // empty slice means no toolsets }, { - name: "invalid page parameter", - params: map[string]any{ - "page": "not-a-number", + name: "specific tools without toolsets - no default toolsets", + cfg: MCPServerConfig{ + EnabledToolsets: nil, + DynamicToolsets: false, + EnabledTools: []string{"get_me"}, }, - expected: PaginationParams{}, - expectError: true, + expectedResult: []string{}, // empty slice when tools specified but no toolsets }, { - name: "invalid perPage parameter", - params: map[string]any{ - "perPage": "not-a-number", + name: "dynamic mode with explicit toolsets removes all and default", + cfg: MCPServerConfig{ + EnabledToolsets: []string{"all", "repos"}, + DynamicToolsets: true, }, - expected: PaginationParams{}, - expectError: true, + expectedResult: []string{"repos"}, // "all" is removed in dynamic mode }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := OptionalPaginationParams(tc.params) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } + result := resolveEnabledToolsets(&tc.cfg) + assert.Equal(t, tc.expectedResult, result) }) } } diff --git a/pkg/http/handler.go b/pkg/http/handler.go new file mode 100644 index 000000000..aae6e456a --- /dev/null +++ b/pkg/http/handler.go @@ -0,0 +1,105 @@ +package http + +import ( + "log/slog" + "net/http" + "strings" + + "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/middleware" + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type InventoryFactoryFunc func(r *http.Request) *inventory.Inventory + +type HttpMcpHandler struct { + config *HTTPServerConfig + apiHosts utils.ApiHost + logger *slog.Logger + t translations.TranslationHelperFunc + repoAccessOpts []lockdown.RepoAccessOption + inventoryFactoryFunc InventoryFactoryFunc +} + +func NewHttpMcpHandler(cfg *HTTPServerConfig, + t translations.TranslationHelperFunc, + apiHosts *utils.ApiHost, + repoAccessOptions []lockdown.RepoAccessOption, + logger *slog.Logger, + inventoryFactory InventoryFactoryFunc) *HttpMcpHandler { + return &HttpMcpHandler{ + config: cfg, + apiHosts: *apiHosts, + logger: logger, + t: t, + repoAccessOpts: repoAccessOptions, + inventoryFactoryFunc: inventoryFactory, + } +} + +func (s *HttpMcpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Set up repo access cache for lockdown mode + deps := github.NewRequestDeps( + &s.apiHosts, + s.config.Version, + s.config.LockdownMode, + s.repoAccessOpts, + s.t, + github.FeatureFlags{ + LockdownMode: s.config.LockdownMode, + }, + s.config.ContentWindowSize, + ) + + inventory := s.inventoryFactoryFunc(r) + + ghServer, err := github.NewMCPServer(&github.MCPServerConfig{ + Version: s.config.Version, + Host: s.config.Host, + Translator: s.t, + ContentWindowSize: s.config.ContentWindowSize, + Logger: s.logger, + RepoAccessTTL: s.config.RepoAccessCacheTTL, + }, deps, inventory) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + + mcpHandler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return ghServer + }, &mcp.StreamableHTTPOptions{ + Stateless: true, + }) + + middleware.ExtractUserToken()(mcpHandler).ServeHTTP(w, r) +} + +func DefaultInventoryFactory(cfg *HTTPServerConfig, t translations.TranslationHelperFunc) InventoryFactoryFunc { + return func(r *http.Request) *inventory.Inventory { + b := github.NewInventory(t).WithDeprecatedAliases(github.DeprecatedToolAliases) + b = InventoryFiltersForRequestHeaders(r, b) + return b.Build() + } +} + +func InventoryFiltersForRequestHeaders(r *http.Request, builder *inventory.Builder) *inventory.Builder { + if r.Header.Get("X-MCP-Readonly") != "" { + builder = builder.WithReadOnly(true) + } + + if toolsetsStr := r.Header.Get("X-MCP-Toolsets"); toolsetsStr != "" { + toolsets := strings.Split(toolsetsStr, ",") + builder = builder.WithToolsets(toolsets) + } + + if toolsStr := r.Header.Get("X-MCP-Tools"); toolsStr != "" { + tools := strings.Split(toolsStr, ",") + builder = builder.WithTools(github.CleanTools(tools)) + } + + return builder +} diff --git a/pkg/http/headers/headers.go b/pkg/http/headers/headers.go new file mode 100644 index 000000000..7b47eeaf6 --- /dev/null +++ b/pkg/http/headers/headers.go @@ -0,0 +1,26 @@ +package headers + +const ( + // AuthorizationHeader is a standard HTTP Header. + AuthorizationHeader = "Authorization" + // ContentTypeHeader is a standard HTTP Header. + ContentTypeHeader = "Content-Type" + // AcceptHeader is a standard HTTP Header. + AcceptHeader = "Accept" + // UserAgentHeader is a standard HTTP Header. + UserAgentHeader = "User-Agent" + + // ContentTypeJSON is the standard MIME type for JSON. + ContentTypeJSON = "application/json" + // ContentTypeEventStream is the standard MIME type for Event Streams. + ContentTypeEventStream = "text/event-stream" + + // ForwardedForHeader is a standard HTTP Header used to forward the originating IP address of a client. + ForwardedForHeader = "X-Forwarded-For" + + // RealIPHeader is a standard HTTP Header used to indicate the real IP address of the client. + RealIPHeader = "X-Real-IP" + + // RequestHmacHeader is used to authenticate requests to the Raw API. + RequestHmacHeader = "Request-Hmac" +) diff --git a/pkg/http/mark/mark.go b/pkg/http/mark/mark.go new file mode 100644 index 000000000..6df0eef7b --- /dev/null +++ b/pkg/http/mark/mark.go @@ -0,0 +1,65 @@ +// Package mark provides a mechnanism for tagging errors with a well-known error value. +package mark + +import "errors" + +// This list of errors is not exhaustive, but is a good starting point for most +// applications. Feel free to add more as needed, but don't go overboard. +// Remember, the specific types of errors are only important so far as someone +// calling your code might want to write logic to handle each type of error +// differently. +// +// Do not add application-specific errors to this list. Instead, just define +// your own package with your own application-specific errors, and use this +// package to mark errors with them. The errors in this package are not special, +// they're just plain old errors. +// +// Not all errors need to be marked. An error that is not marked should be +// treated as an unexpected error that cannot be handled by calling code. This +// is often the case for network errors or logic errors. +var ( + ErrNotFound = errors.New("not found") + ErrAlreadyExists = errors.New("already exists") + ErrBadRequest = errors.New("bad request") + ErrUnauthorized = errors.New("unauthorized") + ErrCancelled = errors.New("request cancelled") + ErrUnavailable = errors.New("unavailable") + ErrTimedout = errors.New("request timed out") + ErrTooLarge = errors.New("request is too large") + ErrTooManyRequests = errors.New("too many requests") + ErrForbidden = errors.New("forbidden") +) + +// With wraps err with another error that will return true from errors.Is and +// errors.As for both err and markErr, and anything either may wrap. +func With(err, markErr error) error { + if err == nil { + return nil + } + return marked{wrapped: err, mark: markErr} +} + +type marked struct { + wrapped error + mark error +} + +func (f marked) Is(target error) bool { + // if this is false, errors.Is will call unwrap and retry on the wrapped + // error. + return errors.Is(f.mark, target) +} + +func (f marked) As(target any) bool { + // if this is false, errors.As will call unwrap and retry on the wrapped + // error. + return errors.As(f.mark, target) +} + +func (f marked) Unwrap() error { + return f.wrapped +} + +func (f marked) Error() string { + return f.mark.Error() + ": " + f.wrapped.Error() +} diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go new file mode 100644 index 000000000..c2e5c6382 --- /dev/null +++ b/pkg/http/middleware/token.go @@ -0,0 +1,103 @@ +package middleware + +import ( + "errors" + "fmt" + "net/http" + "regexp" + "strings" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + httpheaders "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/http/mark" +) + +type authType int + +const ( + authTypeUnknown authType = iota + authTypeIDE + authTypeGhToken +) + +var ( + errMissingAuthorizationHeader = fmt.Errorf("%w: missing required Authorization header", mark.ErrBadRequest) + errBadAuthorizationHeader = fmt.Errorf("%w: Authorization header is badly formatted", mark.ErrBadRequest) + errUnsupportedAuthorizationHeader = fmt.Errorf("%w: unsupported Authorization header", mark.ErrBadRequest) + errMissingTokenInfoHeader = fmt.Errorf("%w: missing required token info header", mark.ErrBadRequest) +) + +var supportedThirdPartyTokenPrefixes = []string{ + "ghp_", // Personal access token (classic) + "github_pat_", // Fine-grained personal access token + "gho_", // OAuth access token + "ghu_", // User access token for a GitHub App + "ghs_", // Installation access token for a GitHub App (a.k.a. server-to-server token) +} + +// oldPatternRegexp is the regular expression for the old pattern of the token. +// Until 2021, GitHub API tokens did not have an identifiable prefix. They +// were 40 characters long and only contained the characters a-f and 0-9. +var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`) + +func ExtractUserToken() func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, token, err := parseAuthorizationHeader(r) + if err != nil { + // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec + if errors.Is(err, errMissingAuthorizationHeader) { + // sendAuthChallenge(w, r, cfg, obsv) + return + } + // For other auth errors (bad format, unsupported), return 400 + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + ctx := r.Context() + ctx = ghcontext.WithTokenInfo(ctx, token) + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) + } +} +func parseAuthorizationHeader(req *http.Request) (authType authType, token string, _ error) { + authHeader := req.Header.Get(httpheaders.AuthorizationHeader) + if authHeader == "" { + return 0, "", errMissingAuthorizationHeader + } + + switch { + // decrypt dotcom token and set it as token + case strings.HasPrefix(authHeader, "GitHub-Bearer "): + return 0, "", errUnsupportedAuthorizationHeader + default: + // support both "Bearer" and "bearer" to conform to api.github.com + if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") { + token = authHeader[7:] + } else { + token = authHeader + } + } + + // Do a naïve check for a colon in the token - currently, only the IDE token has a colon in it. + // ex: tid=1;exp=25145314523;chat=1: + if strings.Contains(token, ":") { + return authTypeIDE, token, nil + } + + for _, prefix := range supportedThirdPartyTokenPrefixes { + if strings.HasPrefix(token, prefix) { + return authTypeGhToken, token, nil + } + } + + matchesOldTokenPattern := oldPatternRegexp.MatchString(token) + if matchesOldTokenPattern { + return authTypeGhToken, token, nil + } + + return 0, "", errBadAuthorizationHeader +} diff --git a/pkg/http/server.go b/pkg/http/server.go new file mode 100644 index 000000000..5e3ca6fb4 --- /dev/null +++ b/pkg/http/server.go @@ -0,0 +1,114 @@ +package http + +import ( + "context" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/go-chi/chi/v5" +) + +type HTTPServerConfig struct { + // Version of the server + Version string + + // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) + Host string + + // ExportTranslations indicates if we should export translations + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#i18n--overriding-descriptions + ExportTranslations bool + + // EnableCommandLogging indicates if we should log commands + EnableCommandLogging bool + + // Path to the log file if not stderr + LogFilePath string + + // Content window size + ContentWindowSize int + + // LockdownMode indicates if we should enable lockdown mode + LockdownMode bool + + // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. + RepoAccessCacheTTL *time.Duration +} + +func RunHTTPServer(cfg HTTPServerConfig) error { + // Create app context + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + t, dumpTranslations := translations.TranslationHelper() + + var slogHandler slog.Handler + var logOutput io.Writer + if cfg.LogFilePath != "" { + file, err := os.OpenFile(cfg.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + if err != nil { + return fmt.Errorf("failed to open log file: %w", err) + } + logOutput = file + slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelDebug}) + } else { + logOutput = os.Stderr + slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelInfo}) + } + logger := slog.New(slogHandler) + logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "lockdownEnabled", cfg.LockdownMode) + + apiHost, err := utils.ParseAPIHost(cfg.Host) + if err != nil { + return fmt.Errorf("failed to parse API host: %w", err) + } + + repoAccessOpts := []lockdown.RepoAccessOption{ + lockdown.WithLogger(logger.With("component", "lockdown")), + } + if cfg.RepoAccessCacheTTL != nil { + repoAccessOpts = append(repoAccessOpts, lockdown.WithTTL(*cfg.RepoAccessCacheTTL)) + } + + handler := NewHttpMcpHandler(&cfg, t, &apiHost, repoAccessOpts, logger, DefaultInventoryFactory(&cfg, t)) + + r := chi.NewRouter() + r.Mount("/", handler) + + httpSvr := http.Server{ + Addr: ":8082", + Handler: r, + } + + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + logger.Info("shutting down server") + if err := httpSvr.Shutdown(shutdownCtx); err != nil { + logger.Error("error during server shutdown", "error", err) + } + }() + + if cfg.ExportTranslations { + // Once server is initialized, all translations are loaded + dumpTranslations() + } + + logger.Info("HTTP server listening on :8082") + if err := httpSvr.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return fmt.Errorf("HTTP server error: %w", err) + } + + logger.Info("server stopped gracefully") + return nil +} diff --git a/pkg/http/transport/bearer.go b/pkg/http/transport/bearer.go new file mode 100644 index 000000000..7e8a698eb --- /dev/null +++ b/pkg/http/transport/bearer.go @@ -0,0 +1,25 @@ +package transport + +import ( + "net/http" + "strings" + + ghcontext "github.com/github/github-mcp-server/pkg/context" +) + +type BearerAuthTransport struct { + Transport http.RoundTripper + Token string +} + +func (t *BearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set("Authorization", "Bearer "+t.Token) + + // Check for GraphQL-Features in context and add header if present + if features := ghcontext.GetGraphQLFeatures(req.Context()); len(features) > 0 { + req.Header.Set("GraphQL-Features", strings.Join(features, ", ")) + } + + return t.Transport.RoundTrip(req) +} diff --git a/pkg/http/transport/user_agent.go b/pkg/http/transport/user_agent.go new file mode 100644 index 000000000..f1cbaeb4d --- /dev/null +++ b/pkg/http/transport/user_agent.go @@ -0,0 +1,14 @@ +package transport + +import "net/http" + +type UserAgentTransport struct { + Transport http.RoundTripper + Agent string +} + +func (t *UserAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set("User-Agent", t.Agent) + return t.Transport.RoundTrip(req) +} diff --git a/pkg/utils/api.go b/pkg/utils/api.go new file mode 100644 index 000000000..5abbf02eb --- /dev/null +++ b/pkg/utils/api.go @@ -0,0 +1,185 @@ +package utils + +import ( + "fmt" + "net/http" + "net/url" + "strings" + "time" +) + +type ApiHost struct { + BaseRESTURL *url.URL + GraphqlURL *url.URL + UploadURL *url.URL + RawURL *url.URL +} + +func newDotcomHost() (ApiHost, error) { + baseRestURL, err := url.Parse("https://api.github.com/") + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse dotcom REST URL: %w", err) + } + + gqlURL, err := url.Parse("https://api.github.com/graphql") + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse dotcom GraphQL URL: %w", err) + } + + uploadURL, err := url.Parse("https://uploads.github.com") + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse dotcom Upload URL: %w", err) + } + + rawURL, err := url.Parse("https://raw.githubusercontent.com/") + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err) + } + + return ApiHost{ + BaseRESTURL: baseRestURL, + GraphqlURL: gqlURL, + UploadURL: uploadURL, + RawURL: rawURL, + }, nil +} + +func newGHECHost(hostname string) (ApiHost, error) { + u, err := url.Parse(hostname) + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse GHEC URL: %w", err) + } + + // Unsecured GHEC would be an error + if u.Scheme == "http" { + return ApiHost{}, fmt.Errorf("GHEC URL must be HTTPS") + } + + restURL, err := url.Parse(fmt.Sprintf("https://api.%s/", u.Hostname())) + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse GHEC REST URL: %w", err) + } + + gqlURL, err := url.Parse(fmt.Sprintf("https://api.%s/graphql", u.Hostname())) + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse GHEC GraphQL URL: %w", err) + } + + uploadURL, err := url.Parse(fmt.Sprintf("https://uploads.%s", u.Hostname())) + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse GHEC Upload URL: %w", err) + } + + rawURL, err := url.Parse(fmt.Sprintf("https://raw.%s/", u.Hostname())) + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err) + } + + return ApiHost{ + BaseRESTURL: restURL, + GraphqlURL: gqlURL, + UploadURL: uploadURL, + RawURL: rawURL, + }, nil +} + +func newGHESHost(hostname string) (ApiHost, error) { + u, err := url.Parse(hostname) + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse GHES URL: %w", err) + } + + restURL, err := url.Parse(fmt.Sprintf("%s://%s/api/v3/", u.Scheme, u.Hostname())) + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse GHES REST URL: %w", err) + } + + gqlURL, err := url.Parse(fmt.Sprintf("%s://%s/api/graphql", u.Scheme, u.Hostname())) + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse GHES GraphQL URL: %w", err) + } + + // Check if subdomain isolation is enabled + // See https://docs.github.com/en/enterprise-server@3.17/admin/configuring-settings/hardening-security-for-your-enterprise/enabling-subdomain-isolation#about-subdomain-isolation + hasSubdomainIsolation := checkSubdomainIsolation(u.Scheme, u.Hostname()) + + var uploadURL *url.URL + if hasSubdomainIsolation { + // With subdomain isolation: https://uploads.hostname/ + uploadURL, err = url.Parse(fmt.Sprintf("%s://uploads.%s/", u.Scheme, u.Hostname())) + } else { + // Without subdomain isolation: https://hostname/api/uploads/ + uploadURL, err = url.Parse(fmt.Sprintf("%s://%s/api/uploads/", u.Scheme, u.Hostname())) + } + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse GHES Upload URL: %w", err) + } + + var rawURL *url.URL + if hasSubdomainIsolation { + // With subdomain isolation: https://raw.hostname/ + rawURL, err = url.Parse(fmt.Sprintf("%s://raw.%s/", u.Scheme, u.Hostname())) + } else { + // Without subdomain isolation: https://hostname/raw/ + rawURL, err = url.Parse(fmt.Sprintf("%s://%s/raw/", u.Scheme, u.Hostname())) + } + if err != nil { + return ApiHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err) + } + + return ApiHost{ + BaseRESTURL: restURL, + GraphqlURL: gqlURL, + UploadURL: uploadURL, + RawURL: rawURL, + }, nil +} + +// checkSubdomainIsolation detects if GitHub Enterprise Server has subdomain isolation enabled +// by attempting to ping the raw./_ping endpoint on the subdomain. The raw subdomain must always exist for subdomain isolation. +func checkSubdomainIsolation(scheme, hostname string) bool { + subdomainURL := fmt.Sprintf("%s://raw.%s/_ping", scheme, hostname) + + client := &http.Client{ + Timeout: 5 * time.Second, + // Don't follow redirects - we just want to check if the endpoint exists + //nolint:revive // parameters are required by http.Client.CheckRedirect signature + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err := client.Get(subdomainURL) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.StatusCode == http.StatusOK +} + +// Note that this does not handle ports yet, so development environments are out. +func ParseAPIHost(s string) (ApiHost, error) { + if s == "" { + return newDotcomHost() + } + + u, err := url.Parse(s) + if err != nil { + return ApiHost{}, fmt.Errorf("could not parse host as URL: %s", s) + } + + if u.Scheme == "" { + return ApiHost{}, fmt.Errorf("host must have a scheme (http or https): %s", s) + } + + if strings.HasSuffix(u.Hostname(), "github.com") { + return newDotcomHost() + } + + if strings.HasSuffix(u.Hostname(), "ghe.com") { + return newGHECHost(s) + } + + return newGHESHost(s) +}