Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions sensor/common/centralproxy/handler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package centralproxy

import (
"context"
"crypto/x509"
"fmt"
"net/http"
Expand All @@ -23,6 +24,22 @@ import (
"k8s.io/client-go/kubernetes"
)

// consoleUserKeyType is the context key for passing the authenticated console username.
type consoleUserKeyType struct{}

var consoleUserKey consoleUserKeyType

// contextWithConsoleUser returns a new context with the console username attached.
func contextWithConsoleUser(ctx context.Context, username string) context.Context {
return context.WithValue(ctx, consoleUserKey, username)
}

// consoleUserFromContext extracts the console username from the context, or returns empty string.
func consoleUserFromContext(ctx context.Context) string {
username, _ := ctx.Value(consoleUserKey).(string)
return username
}

var (
log = logging.LoggerForModule()
k8sClientQPS = 50.0
Expand Down Expand Up @@ -189,6 +206,9 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
return
}

// Pass the authenticated console user to the transport layer via context.
request = request.WithContext(contextWithConsoleUser(request.Context(), userInfo.Username))

if err := h.authorizer.authorize(request.Context(), userInfo, request); err != nil {
result = requestResultAuthzError
http.Error(writer, err.Error(), pkghttputil.StatusFromError(err))
Expand Down
31 changes: 31 additions & 0 deletions sensor/common/centralproxy/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,37 @@ func TestServeHTTP_RequiresAuthentication(t *testing.T) {
})
}

// TestServeHTTP_ConsoleUserContext verifies that the handler attaches the
// authenticated username to the outgoing request context, where the
// transport layer can read it back via consoleUserFromContext.
func TestServeHTTP_ConsoleUserContext(t *testing.T) {
	t.Run("authenticated username is set in request context", func(t *testing.T) {
		setupCentralCapsForTest(t)

		// Capture whatever username the transport observes on the context.
		var gotUser string
		rt := pkghttputil.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
			gotUser = consoleUserFromContext(req.Context())
			resp := &http.Response{
				StatusCode: http.StatusOK,
				Body:       io.NopCloser(strings.NewReader(`{"ok":true}`)),
				Header:     http.Header{"Content-Type": []string{"application/json"}},
			}
			return resp, nil
		})

		centralURL, err := url.Parse("https://central:443")
		require.NoError(t, err)

		handler := newTestHandler(t, centralURL, rt, newAllowingAuthorizer(t), "test-token")
		handler.centralReachable.Store(true)

		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/v1/alerts", nil)
		req.Header.Set("Authorization", "Bearer test-token")

		handler.ServeHTTP(rec, req)

		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, "test-user", gotUser, "console user should be set from authenticated username")
	})
}

func TestServeHTTP_PathFiltering(t *testing.T) {
t.Run("disallowed path returns 403", func(t *testing.T) {
f := newProxyTestFixture(t, newAllowingAuthorizer(t))
Expand Down
79 changes: 50 additions & 29 deletions sensor/common/centralproxy/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ func (t *scopedTokenTransport) SetClient(conn grpc.ClientConnInterface) {
// is retried once with a fresh token.
func (t *scopedTokenTransport) RoundTrip(req *http.Request) (*http.Response, error) {
scope := req.Header.Get(stackroxNamespaceHeader)
username := consoleUserFromContext(req.Context())
key := tokenCacheKey{scope: scope, username: username}

// Buffer the request body upfront so we can replay it on retry.
var bodyBytes []byte
Expand All @@ -131,7 +133,7 @@ func (t *scopedTokenTransport) RoundTrip(req *http.Request) (*http.Response, err
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}

resp, err := t.doRoundTrip(req, scope)
resp, err := t.doRoundTrip(req, key)
if err != nil {
return nil, err
}
Expand All @@ -150,43 +152,59 @@ func (t *scopedTokenTransport) RoundTrip(req *http.Request) (*http.Response, err
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}

t.tokenProvider.invalidateToken(scope)
return t.doRoundTrip(req, scope)
t.tokenProvider.invalidateToken(key)
return t.doRoundTrip(req, key)
}

return resp, nil
}

// doRoundTrip performs a single round trip with token injection.
func (t *scopedTokenTransport) doRoundTrip(req *http.Request, scope string) (*http.Response, error) {
token, err := t.tokenProvider.getTokenForScope(req.Context(), scope)
func (t *scopedTokenTransport) doRoundTrip(req *http.Request, key tokenCacheKey) (*http.Response, error) {
token, err := t.tokenProvider.getTokenForScope(req.Context(), key)
if err != nil {
return nil, errors.Wrap(err, "obtaining authorization token")
}

// Clone the request to avoid modifying the original.
reqCopy := req.Clone(req.Context())
reqCopy.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
// Strip the internal namespace scope header before forwarding to Central.
reqCopy.Header.Del(stackroxNamespaceHeader)

return t.base.RoundTrip(reqCopy) //nolint:wrapcheck
}

// tokenCacheKey combines namespace scope and username for cache isolation.
// Different users requesting the same scope get separate cache entries and tokens.
type tokenCacheKey struct {
	scope    string
	username string
}

// cacheKeyString flattens the key into a single string usable as a coalescer
// key. Each component is URL-escaped first, so the "|" separator can never
// appear inside an encoded component and distinct keys always map to
// distinct strings, regardless of how scope or username formats evolve.
func (k tokenCacheKey) cacheKeyString() string {
	escapedScope := url.QueryEscape(k.scope)
	escapedUser := url.QueryEscape(k.username)
	return escapedScope + "|" + escapedUser
}

// tokenProvider manages dynamic token acquisition from Central.
type tokenProvider struct {
// client holds the Central gRPC client. It is stored atomically to safeguard
// against data races.
client atomic.Pointer[centralv1.TokenServiceClient]
clusterIDGetter clusterIDGetter
tokenCache expiringcache.Cache[string, string]
// tokenGroup coalesces concurrent token requests for the same namespace scope.
tokenCache expiringcache.Cache[tokenCacheKey, string]
// tokenGroup coalesces concurrent token requests for the same namespace scope and user.
tokenGroup *coalescer.Coalescer[string]
}

// newTokenProvider creates a new tokenProvider.
func newTokenProvider(clusterIDGetter clusterIDGetter) *tokenProvider {
return &tokenProvider{
clusterIDGetter: clusterIDGetter,
tokenCache: expiringcache.NewExpiringCache[string, string](tokenCacheTTL),
tokenCache: expiringcache.NewExpiringCache[tokenCacheKey, string](tokenCacheTTL),
tokenGroup: coalescer.New[string](),
}
}
Expand All @@ -201,84 +219,86 @@ func (p *tokenProvider) setClient(conn grpc.ClientConnInterface) {
p.client.Store(&client)
}

// getTokenForScope returns a token for the given namespace scope.
// getTokenForScope returns a token for the given namespace scope and console user.
// Scope values:
// - "" (empty): Token with empty access scope (authentication only)
// - "<namespace>": Token scoped to the specific namespace
// - FullClusterAccessScope ("*"): Token with full cluster access
//
// Concurrent requests for the same scope are coalesced to reduce load on Central.
func (p *tokenProvider) getTokenForScope(ctx context.Context, namespaceScope string) (string, error) {
// Concurrent requests for the same scope and user are coalesced to reduce load on Central.
func (p *tokenProvider) getTokenForScope(ctx context.Context, key tokenCacheKey) (string, error) {
client := p.client.Load()
if client == nil {
incrementTokenRequest(tokenResultError)
return "", errors.Wrap(errServiceUnavailable, "token provider not initialized: central connection not available")
}

// Fast path: check cache first.
if token, ok := p.tokenCache.Get(namespaceScope); ok {
if token, ok := p.tokenCache.Get(key); ok {
incrementTokenRequest(tokenResultCacheHit)
return token, nil
}

// Slow path: coalesce concurrent requests for the same scope.
return p.tokenGroup.Coalesce(ctx, namespaceScope, func() (string, error) { //nolint:wrapcheck
coalescerKey := key.cacheKeyString()

// Slow path: coalesce concurrent requests for the same scope and user.
return p.tokenGroup.Coalesce(ctx, coalescerKey, func() (string, error) { //nolint:wrapcheck
// Double-check cache inside coalesce to avoid redundant API calls.
if token, ok := p.tokenCache.Get(namespaceScope); ok {
if token, ok := p.tokenCache.Get(key); ok {
incrementTokenRequest(tokenResultCacheHit)
return token, nil
}

log.Debugf("Token cache miss for namespace scope %q, requesting from Central", namespaceScope)
log.Debugf("Token cache miss for namespace scope %q, requesting from Central", key.scope)

// Use a background context with timeout to ensure the shared function is independent
// of the initial request context while still having a bounded lifetime.
ctx, cancel := context.WithTimeout(context.Background(), tokenRequestTimeout)
defer cancel()
token, err := p.requestToken(ctx, *client, namespaceScope)
token, err := p.requestToken(ctx, *client, key)
if err != nil {
incrementTokenRequest(tokenResultError)
return "", err
}

p.tokenCache.Add(namespaceScope, token)
p.tokenCache.Add(key, token)
incrementTokenRequest(tokenResultSuccess)
return token, nil
})
}

// requestToken performs the RPC call to Central to generate a token for the given scope.
func (p *tokenProvider) requestToken(ctx context.Context, client centralv1.TokenServiceClient, namespaceScope string) (string, error) {
req, err := p.buildTokenRequest(namespaceScope)
// requestToken performs the RPC call to Central to generate a token for the given scope and user.
func (p *tokenProvider) requestToken(ctx context.Context, client centralv1.TokenServiceClient, key tokenCacheKey) (string, error) {
req, err := p.buildTokenRequest(key.scope, key.username)
if err != nil {
return "", errors.Wrap(err, "building token request")
}
resp, err := client.GenerateTokenForPermissionsAndScope(ctx, req)
if err != nil {
return "", errors.Wrapf(err, "requesting token from Central for scope %q", namespaceScope)
return "", errors.Wrapf(err, "requesting token from Central for scope %q", key.scope)
}

token := resp.GetToken()
if token == "" {
return "", errors.Errorf("received empty token from Central for scope %q", namespaceScope)
return "", errors.Errorf("received empty token from Central for scope %q", key.scope)
}

return token, nil
}

// invalidateToken removes the cached token for the given scope.
// invalidateToken removes the cached token for the given scope and user.
// It also removes the coalescer key so subsequent callers will
// trigger a fresh token request rather than joining any in-progress request.
// Note: This does not cancel already running requests; they will complete
// normally but their results will not be used by new callers.
func (p *tokenProvider) invalidateToken(scope string) {
p.tokenCache.Remove(scope)
p.tokenGroup.Forget(scope)
func (p *tokenProvider) invalidateToken(key tokenCacheKey) {
p.tokenCache.Remove(key)
p.tokenGroup.Forget(key.cacheKeyString())
}

// buildTokenRequest creates the token request based on the namespace scope.
// buildTokenRequest creates the token request based on the namespace scope and username.
// Returns an error if the cluster ID is not available yet.
func (p *tokenProvider) buildTokenRequest(namespaceScope string) (*centralv1.GenerateTokenForPermissionsAndScopeRequest, error) {
func (p *tokenProvider) buildTokenRequest(namespaceScope string, username string) (*centralv1.GenerateTokenForPermissionsAndScopeRequest, error) {
clusterID := p.clusterIDGetter.GetNoWait()
if clusterID == "" {
return nil, errors.Wrap(errServiceUnavailable, "cluster ID not available")
Expand All @@ -287,6 +307,7 @@ func (p *tokenProvider) buildTokenRequest(namespaceScope string) (*centralv1.Gen
req := &centralv1.GenerateTokenForPermissionsAndScopeRequest{
Permissions: tokenPermissions,
Lifetime: durationpb.New(tokenTTL),
Requester: username,
}

switch namespaceScope {
Expand Down
Loading
Loading