From 56e55bd6e5e83bc3cb90a14a67d1ea197258a911 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stephan=20He=C3=9Felmann?= Date: Mon, 30 Mar 2026 10:34:46 +0200 Subject: [PATCH] ROX-33256: add console user to requester claim --- sensor/common/centralproxy/handler.go | 20 ++ sensor/common/centralproxy/handler_test.go | 31 ++ sensor/common/centralproxy/transport.go | 79 ++-- sensor/common/centralproxy/transport_test.go | 360 ++++++++++++++----- 4 files changed, 379 insertions(+), 111 deletions(-) diff --git a/sensor/common/centralproxy/handler.go b/sensor/common/centralproxy/handler.go index 29c39af274d27..b9d5bb92c0980 100644 --- a/sensor/common/centralproxy/handler.go +++ b/sensor/common/centralproxy/handler.go @@ -1,6 +1,7 @@ package centralproxy import ( + "context" "crypto/x509" "fmt" "net/http" @@ -23,6 +24,22 @@ import ( "k8s.io/client-go/kubernetes" ) +// consoleUserKeyType is the context key for passing the authenticated console username. +type consoleUserKeyType struct{} + +var consoleUserKey consoleUserKeyType + +// contextWithConsoleUser returns a new context with the console username attached. +func contextWithConsoleUser(ctx context.Context, username string) context.Context { + return context.WithValue(ctx, consoleUserKey, username) +} + +// consoleUserFromContext extracts the console username from the context, or returns empty string. +func consoleUserFromContext(ctx context.Context) string { + username, _ := ctx.Value(consoleUserKey).(string) + return username +} + var ( log = logging.LoggerForModule() k8sClientQPS = 50.0 @@ -189,6 +206,9 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { return } + // Pass the authenticated console user to the transport layer via context. 
+ request = request.WithContext(contextWithConsoleUser(request.Context(), userInfo.Username)) + if err := h.authorizer.authorize(request.Context(), userInfo, request); err != nil { result = requestResultAuthzError http.Error(writer, err.Error(), pkghttputil.StatusFromError(err)) diff --git a/sensor/common/centralproxy/handler_test.go b/sensor/common/centralproxy/handler_test.go index 74f42c8fafdfe..9f246f8c953dc 100644 --- a/sensor/common/centralproxy/handler_test.go +++ b/sensor/common/centralproxy/handler_test.go @@ -555,6 +555,37 @@ func TestServeHTTP_RequiresAuthentication(t *testing.T) { }) } +func TestServeHTTP_ConsoleUserContext(t *testing.T) { + t.Run("authenticated username is set in request context", func(t *testing.T) { + setupCentralCapsForTest(t) + + var capturedUsername string + mockTransport := pkghttputil.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + capturedUsername = consoleUserFromContext(req.Context()) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"ok":true}`)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + }) + + baseURL, err := url.Parse("https://central:443") + require.NoError(t, err) + + h := newTestHandler(t, baseURL, mockTransport, newAllowingAuthorizer(t), "test-token") + h.centralReachable.Store(true) + + req := httptest.NewRequest(http.MethodGet, "/v1/alerts", nil) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "test-user", capturedUsername, "console user should be set from authenticated username") + }) +} + func TestServeHTTP_PathFiltering(t *testing.T) { t.Run("disallowed path returns 403", func(t *testing.T) { f := newProxyTestFixture(t, newAllowingAuthorizer(t)) diff --git a/sensor/common/centralproxy/transport.go b/sensor/common/centralproxy/transport.go index 4de30ec4a84d8..b26011fa812a6 100644 --- 
a/sensor/common/centralproxy/transport.go +++ b/sensor/common/centralproxy/transport.go @@ -119,6 +119,8 @@ func (t *scopedTokenTransport) SetClient(conn grpc.ClientConnInterface) { // is retried once with a fresh token. func (t *scopedTokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { scope := req.Header.Get(stackroxNamespaceHeader) + username := consoleUserFromContext(req.Context()) + key := tokenCacheKey{scope: scope, username: username} // Buffer the request body upfront so we can replay it on retry. var bodyBytes []byte @@ -131,7 +133,7 @@ func (t *scopedTokenTransport) RoundTrip(req *http.Request) (*http.Response, err req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) } - resp, err := t.doRoundTrip(req, scope) + resp, err := t.doRoundTrip(req, key) if err != nil { return nil, err } @@ -150,16 +152,16 @@ func (t *scopedTokenTransport) RoundTrip(req *http.Request) (*http.Response, err req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) } - t.tokenProvider.invalidateToken(scope) - return t.doRoundTrip(req, scope) + t.tokenProvider.invalidateToken(key) + return t.doRoundTrip(req, key) } return resp, nil } // doRoundTrip performs a single round trip with token injection. -func (t *scopedTokenTransport) doRoundTrip(req *http.Request, scope string) (*http.Response, error) { - token, err := t.tokenProvider.getTokenForScope(req.Context(), scope) +func (t *scopedTokenTransport) doRoundTrip(req *http.Request, key tokenCacheKey) (*http.Response, error) { + token, err := t.tokenProvider.getTokenForScope(req.Context(), key) if err != nil { return nil, errors.Wrap(err, "obtaining authorization token") } @@ -167,18 +169,34 @@ func (t *scopedTokenTransport) doRoundTrip(req *http.Request, scope string) (*ht // Clone the request to avoid modifying the original. reqCopy := req.Clone(req.Context()) reqCopy.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + // Strip the internal namespace scope header before forwarding to Central. 
+ reqCopy.Header.Del(stackroxNamespaceHeader) return t.base.RoundTrip(reqCopy) //nolint:wrapcheck } +// tokenCacheKey combines namespace scope and username for cache isolation. +// Different users requesting the same scope get separate cache entries and tokens. +type tokenCacheKey struct { + scope string + username string +} + +// cacheKeyString returns a string representation for coalescer keys. +// Components are URL-escaped so the "|" delimiter cannot appear in encoded values, +// preventing collisions even if scope or username formats change. +func (k tokenCacheKey) cacheKeyString() string { + return url.QueryEscape(k.scope) + "|" + url.QueryEscape(k.username) +} + // tokenProvider manages dynamic token acquisition from Central. type tokenProvider struct { // client holds the Central gRPC client. It is stored atomically to safeguard // against data races. client atomic.Pointer[centralv1.TokenServiceClient] clusterIDGetter clusterIDGetter - tokenCache expiringcache.Cache[string, string] - // tokenGroup coalesces concurrent token requests for the same namespace scope. + tokenCache expiringcache.Cache[tokenCacheKey, string] + // tokenGroup coalesces concurrent token requests for the same namespace scope and user. tokenGroup *coalescer.Coalescer[string] } @@ -186,7 +204,7 @@ type tokenProvider struct { func newTokenProvider(clusterIDGetter clusterIDGetter) *tokenProvider { return &tokenProvider{ clusterIDGetter: clusterIDGetter, - tokenCache: expiringcache.NewExpiringCache[string, string](tokenCacheTTL), + tokenCache: expiringcache.NewExpiringCache[tokenCacheKey, string](tokenCacheTTL), tokenGroup: coalescer.New[string](), } } @@ -201,14 +219,14 @@ func (p *tokenProvider) setClient(conn grpc.ClientConnInterface) { p.client.Store(&client) } -// getTokenForScope returns a token for the given namespace scope. +// getTokenForScope returns a token for the given namespace scope and console user. 
// Scope values: // - "" (empty): Token with empty access scope (authentication only) // - "<namespace>": Token scoped to the specific namespace // - FullClusterAccessScope ("*"): Token with full cluster access // -// Concurrent requests for the same scope are coalesced to reduce load on Central. -func (p *tokenProvider) getTokenForScope(ctx context.Context, namespaceScope string) (string, error) { +// Concurrent requests for the same scope and user are coalesced to reduce load on Central. +func (p *tokenProvider) getTokenForScope(ctx context.Context, key tokenCacheKey) (string, error) { client := p.client.Load() if client == nil { incrementTokenRequest(tokenResultError) @@ -216,69 +234,71 @@ func (p *tokenProvider) getTokenForScope(ctx context.Context, namespaceScope str } // Fast path: check cache first. - if token, ok := p.tokenCache.Get(namespaceScope); ok { + if token, ok := p.tokenCache.Get(key); ok { incrementTokenRequest(tokenResultCacheHit) return token, nil } - // Slow path: coalesce concurrent requests for the same scope. - return p.tokenGroup.Coalesce(ctx, namespaceScope, func() (string, error) { //nolint:wrapcheck + coalescerKey := key.cacheKeyString() + + // Slow path: coalesce concurrent requests for the same scope and user. + return p.tokenGroup.Coalesce(ctx, coalescerKey, func() (string, error) { //nolint:wrapcheck + // Double-check cache inside coalesce to avoid redundant API calls. - if token, ok := p.tokenCache.Get(namespaceScope); ok { + if token, ok := p.tokenCache.Get(key); ok { incrementTokenRequest(tokenResultCacheHit) return token, nil } - log.Debugf("Token cache miss for namespace scope %q, requesting from Central", namespaceScope) + log.Debugf("Token cache miss for namespace scope %q, requesting from Central", key.scope) // Use a background context with timeout to ensure the shared function is independent // of the initial request context while still having a bounded lifetime. 
ctx, cancel := context.WithTimeout(context.Background(), tokenRequestTimeout) defer cancel() - token, err := p.requestToken(ctx, *client, namespaceScope) + token, err := p.requestToken(ctx, *client, key) if err != nil { incrementTokenRequest(tokenResultError) return "", err } - p.tokenCache.Add(namespaceScope, token) + p.tokenCache.Add(key, token) incrementTokenRequest(tokenResultSuccess) return token, nil }) } -// requestToken performs the RPC call to Central to generate a token for the given scope. -func (p *tokenProvider) requestToken(ctx context.Context, client centralv1.TokenServiceClient, namespaceScope string) (string, error) { - req, err := p.buildTokenRequest(namespaceScope) +// requestToken performs the RPC call to Central to generate a token for the given scope and user. +func (p *tokenProvider) requestToken(ctx context.Context, client centralv1.TokenServiceClient, key tokenCacheKey) (string, error) { + req, err := p.buildTokenRequest(key.scope, key.username) if err != nil { return "", errors.Wrap(err, "building token request") } resp, err := client.GenerateTokenForPermissionsAndScope(ctx, req) if err != nil { - return "", errors.Wrapf(err, "requesting token from Central for scope %q", namespaceScope) + return "", errors.Wrapf(err, "requesting token from Central for scope %q", key.scope) } token := resp.GetToken() if token == "" { - return "", errors.Errorf("received empty token from Central for scope %q", namespaceScope) + return "", errors.Errorf("received empty token from Central for scope %q", key.scope) } return token, nil } -// invalidateToken removes the cached token for the given scope. +// invalidateToken removes the cached token for the given scope and user. // It also removes the coalescer key so subsequent callers will // trigger a fresh token request rather than joining any in-progress request. // Note: This does not cancel already running requests; they will complete // normally but their results will not be used by new callers. 
-func (p *tokenProvider) invalidateToken(scope string) { - p.tokenCache.Remove(scope) - p.tokenGroup.Forget(scope) +func (p *tokenProvider) invalidateToken(key tokenCacheKey) { + p.tokenCache.Remove(key) + p.tokenGroup.Forget(key.cacheKeyString()) } -// buildTokenRequest creates the token request based on the namespace scope. +// buildTokenRequest creates the token request based on the namespace scope and username. // Returns an error if the cluster ID is not available yet. -func (p *tokenProvider) buildTokenRequest(namespaceScope string) (*centralv1.GenerateTokenForPermissionsAndScopeRequest, error) { +func (p *tokenProvider) buildTokenRequest(namespaceScope string, username string) (*centralv1.GenerateTokenForPermissionsAndScopeRequest, error) { clusterID := p.clusterIDGetter.GetNoWait() if clusterID == "" { return nil, errors.Wrap(errServiceUnavailable, "cluster ID not available") @@ -287,6 +307,7 @@ func (p *tokenProvider) buildTokenRequest(namespaceScope string) (*centralv1.Gen req := ¢ralv1.GenerateTokenForPermissionsAndScopeRequest{ Permissions: tokenPermissions, Lifetime: durationpb.New(tokenTTL), + Requester: username, } switch namespaceScope { diff --git a/sensor/common/centralproxy/transport_test.go b/sensor/common/centralproxy/transport_test.go index 300e8a8e96c5b..3c90df91c4c25 100644 --- a/sensor/common/centralproxy/transport_test.go +++ b/sensor/common/centralproxy/transport_test.go @@ -34,10 +34,10 @@ type fakeTokenServiceClient struct { response *centralv1.GenerateTokenForPermissionsAndScopeResponse err error - // Capture the request for verification + // Capture the request for verification. lastRequest *centralv1.GenerateTokenForPermissionsAndScopeRequest - // callCount tracks the number of RPC calls made (optional, set to non-nil to enable) + // callCount tracks the number of RPC calls made (optional, set to non-nil to enable). 
callCount *atomic.Int32 } @@ -57,7 +57,7 @@ func (f *fakeTokenServiceClient) GenerateTokenForPermissionsAndScope( func newTestTokenProvider(client centralv1.TokenServiceClient, clusterID string) *tokenProvider { tp := &tokenProvider{ clusterIDGetter: &fakeClusterIDGetter{clusterID: clusterID}, - tokenCache: expiringcache.NewExpiringCache[string, string](tokenCacheTTL), + tokenCache: expiringcache.NewExpiringCache[tokenCacheKey, string](tokenCacheTTL), tokenGroup: coalescer.New[string](), } if client != nil { @@ -66,6 +66,11 @@ func newTestTokenProvider(client centralv1.TokenServiceClient, clusterID string) return tp } +// scopeKey creates a tokenCacheKey with the given scope and empty username. +func scopeKey(scope string) tokenCacheKey { + return tokenCacheKey{scope: scope} +} + func TestScopedTokenTransport_RoundTrip(t *testing.T) { tests := []struct { name string @@ -95,8 +100,10 @@ func TestScopedTokenTransport_RoundTrip(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var capturedAuthHeader string + var capturedNamespaceHeader string mockBase := roundTripperFunc(func(req *http.Request) (*http.Response, error) { capturedAuthHeader = req.Header.Get("Authorization") + capturedNamespaceHeader = req.Header.Get(stackroxNamespaceHeader) return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{"ok":true}`)), @@ -125,6 +132,7 @@ func TestScopedTokenTransport_RoundTrip(t *testing.T) { require.NotNil(t, resp) assert.Equal(t, "Bearer "+tt.expectedToken, capturedAuthHeader) + assert.Empty(t, capturedNamespaceHeader, "namespace scope header should be stripped before forwarding") }) } } @@ -136,7 +144,7 @@ func TestScopedTokenTransport_RoundTrip_Error(t *testing.T) { return nil, nil }) - // Token provider with no client set - will return error + // Token provider with no client set - will return error. 
transport := &scopedTokenTransport{ base: mockBase, tokenProvider: newTestTokenProvider(nil, "test-cluster-id"), @@ -153,7 +161,7 @@ func TestScopedTokenTransport_RoundTrip_Error(t *testing.T) { t.Run("base transport error propagates", func(t *testing.T) { baseErr := errors.New("connection refused") mockBase := roundTripperFunc(func(req *http.Request) (*http.Response, error) { - // Verify the auth header was set before the error + // Verify the auth header was set before the error. assert.NotEmpty(t, req.Header.Get("Authorization")) return nil, baseErr }) @@ -177,6 +185,82 @@ func TestScopedTokenTransport_RoundTrip_Error(t *testing.T) { }) } +func TestScopedTokenTransport_RoundTrip_ConsoleUser(t *testing.T) { + t.Run("username from context is passed to token request", func(t *testing.T) { + fakeClient := &fakeTokenServiceClient{ + response: ¢ralv1.GenerateTokenForPermissionsAndScopeResponse{ + Token: "user-token", + }, + } + + mockBase := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"ok":true}`)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + }) + + transport := &scopedTokenTransport{ + base: mockBase, + tokenProvider: newTestTokenProvider(fakeClient, "test-cluster-id"), + } + + req := httptest.NewRequest(http.MethodGet, "/v1/alerts", nil) + req.Header.Set(stackroxNamespaceHeader, "my-namespace") + req = req.WithContext(contextWithConsoleUser(req.Context(), "console-admin")) + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.Equal(t, "console-admin", fakeClient.lastRequest.GetRequester()) + }) + + t.Run("different users with same scope get separate cache entries", func(t *testing.T) { + tokenIndex := 0 + tokens := []string{"token-user-a", "token-user-b"} + fakeClient := &dynamicFakeTokenServiceClient{ + getToken: func() string { + token := tokens[tokenIndex] + 
tokenIndex++ + return token + }, + } + + mockBase := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"ok":true}`)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + }) + + transport := &scopedTokenTransport{ + base: mockBase, + tokenProvider: newTestTokenProvider(fakeClient, "test-cluster-id"), + } + + // Request as user-a. + reqA := httptest.NewRequest(http.MethodGet, "/v1/alerts", nil) + reqA.Header.Set(stackroxNamespaceHeader, "shared-ns") + reqA = reqA.WithContext(contextWithConsoleUser(reqA.Context(), "user-a")) + respA, err := transport.RoundTrip(reqA) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, respA.StatusCode) + + // Request as user-b with same scope should get a different token. + reqB := httptest.NewRequest(http.MethodGet, "/v1/alerts", nil) + reqB.Header.Set(stackroxNamespaceHeader, "shared-ns") + reqB = reqB.WithContext(contextWithConsoleUser(reqB.Context(), "user-b")) + respB, err := transport.RoundTrip(reqB) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, respB.StatusCode) + + assert.Equal(t, 2, tokenIndex, "each user should trigger a separate token request") + }) +} + func TestTokenProvider_GetTokenForScope(t *testing.T) { tests := []struct { name string @@ -218,19 +302,19 @@ func TestTokenProvider_GetTokenForScope(t *testing.T) { provider := newTestTokenProvider(fakeClient, tt.clusterID) - token, err := provider.getTokenForScope(context.Background(), tt.namespaceScope) + token, err := provider.getTokenForScope(context.Background(), scopeKey(tt.namespaceScope)) require.NoError(t, err) assert.Equal(t, "test-token-123", token) - // Verify request + // Verify request. req := fakeClient.lastRequest require.NotNil(t, req) - // Verify permissions + // Verify permissions. 
assert.Equal(t, centralv1.Access_READ_ACCESS, req.GetPermissions()["Image"]) assert.Equal(t, centralv1.Access_READ_ACCESS, req.GetPermissions()["Deployment"]) - // Verify scopes + // Verify scopes. if tt.wantScopes { require.Len(t, req.GetClusterScopes(), 1) scope := req.GetClusterScopes()[0] @@ -249,6 +333,24 @@ func TestTokenProvider_GetTokenForScope(t *testing.T) { } } +func TestTokenProvider_GetTokenForScope_Requester(t *testing.T) { + t.Run("requester is set on token request", func(t *testing.T) { + fakeClient := &fakeTokenServiceClient{ + response: ¢ralv1.GenerateTokenForPermissionsAndScopeResponse{ + Token: "test-token", + }, + } + + provider := newTestTokenProvider(fakeClient, "test-cluster-id") + + key := tokenCacheKey{scope: "my-namespace", username: "console-user"} + _, err := provider.getTokenForScope(context.Background(), key) + require.NoError(t, err) + + assert.Equal(t, "console-user", fakeClient.lastRequest.GetRequester()) + }) +} + func TestTokenProvider_Caching(t *testing.T) { t.Run("tokens are cached", func(t *testing.T) { callCount := 0 @@ -261,20 +363,20 @@ func TestTokenProvider_Caching(t *testing.T) { provider := newTestTokenProvider(fakeClient, "test-cluster-id") - // First call should hit the API - token1, err := provider.getTokenForScope(context.Background(), "namespace-a") + // First call should hit the API. + token1, err := provider.getTokenForScope(context.Background(), scopeKey("namespace-a")) require.NoError(t, err) assert.Equal(t, "cached-token", token1) assert.Equal(t, 1, callCount) - // Second call with same scope should use cache - token2, err := provider.getTokenForScope(context.Background(), "namespace-a") + // Second call with same scope should use cache. 
+ token2, err := provider.getTokenForScope(context.Background(), scopeKey("namespace-a")) require.NoError(t, err) assert.Equal(t, "cached-token", token2) assert.Equal(t, 1, callCount, "should use cached token") - // Third call with different scope should hit API again - token3, err := provider.getTokenForScope(context.Background(), "namespace-b") + // Third call with different scope should hit API again. + token3, err := provider.getTokenForScope(context.Background(), scopeKey("namespace-b")) require.NoError(t, err) assert.Equal(t, "cached-token", token3) assert.Equal(t, 2, callCount, "different scope should request new token") @@ -294,33 +396,66 @@ func TestTokenProvider_Caching(t *testing.T) { provider := newTestTokenProvider(fakeClient, "test-cluster-id") - // Request for scope A - tokenA, err := provider.getTokenForScope(context.Background(), "scope-a") + // Request for scope A. + tokenA, err := provider.getTokenForScope(context.Background(), scopeKey("scope-a")) require.NoError(t, err) assert.Equal(t, "token-1", tokenA) - // Request for scope B (different scope, should get new token) - tokenB, err := provider.getTokenForScope(context.Background(), "scope-b") + // Request for scope B (different scope, should get new token). + tokenB, err := provider.getTokenForScope(context.Background(), scopeKey("scope-b")) require.NoError(t, err) assert.Equal(t, "token-2", tokenB) - // Request for scope A again (should use cached) - tokenACached, err := provider.getTokenForScope(context.Background(), "scope-a") + // Request for scope A again (should use cached). + tokenACached, err := provider.getTokenForScope(context.Background(), scopeKey("scope-a")) require.NoError(t, err) assert.Equal(t, "token-1", tokenACached) - // Request for scope B again (should use cached) - tokenBCached, err := provider.getTokenForScope(context.Background(), "scope-b") + // Request for scope B again (should use cached). 
+ tokenBCached, err := provider.getTokenForScope(context.Background(), scopeKey("scope-b")) require.NoError(t, err) assert.Equal(t, "token-2", tokenBCached) }) + + t.Run("different users with same scope get separate cache entries", func(t *testing.T) { + tokenIndex := 0 + tokens := []string{"token-user-a", "token-user-b"} + + fakeClient := &dynamicFakeTokenServiceClient{ + getToken: func() string { + token := tokens[tokenIndex] + tokenIndex++ + return token + }, + } + + provider := newTestTokenProvider(fakeClient, "test-cluster-id") + + keyA := tokenCacheKey{scope: "shared-scope", username: "user-a"} + keyB := tokenCacheKey{scope: "shared-scope", username: "user-b"} + + // Request for user-a. + tokenA, err := provider.getTokenForScope(context.Background(), keyA) + require.NoError(t, err) + assert.Equal(t, "token-user-a", tokenA) + + // Request for user-b (same scope, different user, should get new token). + tokenB, err := provider.getTokenForScope(context.Background(), keyB) + require.NoError(t, err) + assert.Equal(t, "token-user-b", tokenB) + + // Request for user-a again (should use cached). + tokenACached, err := provider.getTokenForScope(context.Background(), keyA) + require.NoError(t, err) + assert.Equal(t, "token-user-a", tokenACached) + }) } func TestTokenProvider_ErrorHandling(t *testing.T) { t.Run("no client returns error", func(t *testing.T) { provider := newTestTokenProvider(nil, "test-cluster-id") - _, err := provider.getTokenForScope(context.Background(), "namespace") + _, err := provider.getTokenForScope(context.Background(), scopeKey("namespace")) require.Error(t, err) assert.Contains(t, err.Error(), "not initialized") }) @@ -328,18 +463,18 @@ func TestTokenProvider_ErrorHandling(t *testing.T) { t.Run("empty token response returns error", func(t *testing.T) { fakeClient := &fakeTokenServiceClient{ response: ¢ralv1.GenerateTokenForPermissionsAndScopeResponse{ - Token: "", // Empty token + Token: "", // Empty token. 
}, } provider := newTestTokenProvider(fakeClient, "test-cluster-id") - _, err := provider.getTokenForScope(context.Background(), "namespace") + _, err := provider.getTokenForScope(context.Background(), scopeKey("namespace")) require.Error(t, err) assert.Contains(t, err.Error(), "empty token") - // Ensure nothing was cached for this scope - _, found := provider.tokenCache.Get("namespace") + // Ensure nothing was cached for this scope. + _, found := provider.tokenCache.Get(scopeKey("namespace")) assert.False(t, found) }) @@ -352,12 +487,12 @@ func TestTokenProvider_ErrorHandling(t *testing.T) { provider := newTestTokenProvider(fakeClient, "test-cluster-id") - _, err := provider.getTokenForScope(context.Background(), "namespace-error") + _, err := provider.getTokenForScope(context.Background(), scopeKey("namespace-error")) require.Error(t, err) assert.Contains(t, err.Error(), "rpc failure") - // Ensure token is not cached for the failing scope - _, found := provider.tokenCache.Get("namespace-error") + // Ensure token is not cached for the failing scope. 
+ _, found := provider.tokenCache.Get(scopeKey("namespace-error")) assert.False(t, found) }) } @@ -368,16 +503,17 @@ func TestBuildTokenRequest(t *testing.T) { } t.Run("empty scope", func(t *testing.T) { - req, err := provider.buildTokenRequest("") + req, err := provider.buildTokenRequest("", "") require.NoError(t, err) assert.Equal(t, centralv1.Access_READ_ACCESS, req.GetPermissions()["Image"]) assert.Equal(t, centralv1.Access_READ_ACCESS, req.GetPermissions()["Deployment"]) assert.Empty(t, req.GetClusterScopes()) assert.NotNil(t, req.GetLifetime()) + assert.Empty(t, req.GetRequester()) }) t.Run("specific namespace", func(t *testing.T) { - req, err := provider.buildTokenRequest("prod") + req, err := provider.buildTokenRequest("prod", "") require.NoError(t, err) require.Len(t, req.GetClusterScopes(), 1) assert.Equal(t, "my-cluster-id", req.GetClusterScopes()[0].GetClusterId()) @@ -386,7 +522,7 @@ func TestBuildTokenRequest(t *testing.T) { }) t.Run("cluster-wide scope", func(t *testing.T) { - req, err := provider.buildTokenRequest(FullClusterAccessScope) + req, err := provider.buildTokenRequest(FullClusterAccessScope, "") require.NoError(t, err) require.Len(t, req.GetClusterScopes(), 1) assert.Equal(t, "my-cluster-id", req.GetClusterScopes()[0].GetClusterId()) @@ -398,11 +534,17 @@ func TestBuildTokenRequest(t *testing.T) { emptyProvider := &tokenProvider{ clusterIDGetter: &fakeClusterIDGetter{clusterID: ""}, } - req, err := emptyProvider.buildTokenRequest("namespace") + req, err := emptyProvider.buildTokenRequest("namespace", "") require.Error(t, err) assert.Nil(t, req) assert.Contains(t, err.Error(), "cluster ID not available") }) + + t.Run("requester is set in request", func(t *testing.T) { + req, err := provider.buildTokenRequest("prod", "console-admin") + require.NoError(t, err) + assert.Equal(t, "console-admin", req.GetRequester()) + }) } // dynamicFakeTokenServiceClient allows dynamic token generation for testing. 
@@ -433,7 +575,7 @@ func TestScopedTokenTransport_InvalidateOnUnauthorized(t *testing.T) { requestCount := 0 mockBase := roundTripperFunc(func(req *http.Request) (*http.Response, error) { requestCount++ - // First request returns 401, retry returns 200 + // First request returns 401, retry returns 200. if req.Header.Get("Authorization") == "Bearer token-1" { return &http.Response{ StatusCode: http.StatusUnauthorized, @@ -456,7 +598,7 @@ func TestScopedTokenTransport_InvalidateOnUnauthorized(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/v1/alerts", nil) req.Header.Set(stackroxNamespaceHeader, "test-namespace") - // Single RoundTrip should retry internally and return success + // Single RoundTrip should retry internally and return success. resp, err := transport.RoundTrip(req) require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode, "should return success after retry") @@ -476,7 +618,7 @@ func TestScopedTokenTransport_InvalidateOnUnauthorized(t *testing.T) { requestCount := 0 mockBase := roundTripperFunc(func(req *http.Request) (*http.Response, error) { requestCount++ - // First request returns 403, retry returns 200 + // First request returns 403, retry returns 200. if req.Header.Get("Authorization") == "Bearer token-1" { return &http.Response{ StatusCode: http.StatusForbidden, @@ -499,7 +641,7 @@ func TestScopedTokenTransport_InvalidateOnUnauthorized(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/v1/alerts", nil) req.Header.Set(stackroxNamespaceHeader, "test-namespace") - // Single RoundTrip should retry internally and return success + // Single RoundTrip should retry internally and return success. 
resp, err := transport.RoundTrip(req) require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode, "should return success after retry") @@ -519,7 +661,7 @@ func TestScopedTokenTransport_InvalidateOnUnauthorized(t *testing.T) { requestCount := 0 mockBase := roundTripperFunc(func(req *http.Request) (*http.Response, error) { requestCount++ - // Always return 401 + // Always return 401. return &http.Response{ StatusCode: http.StatusUnauthorized, Body: io.NopCloser(strings.NewReader(`{"error":"unauthorized"}`)), @@ -535,7 +677,7 @@ func TestScopedTokenTransport_InvalidateOnUnauthorized(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/v1/alerts", nil) req.Header.Set(stackroxNamespaceHeader, "test-namespace") - // Should retry once and then return the 401 + // Should retry once and then return the 401. resp, err := transport.RoundTrip(req) require.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "should return 401 after retry fails") @@ -576,12 +718,12 @@ func TestScopedTokenTransport_InvalidateOnUnauthorized(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/v1/alerts", nil) req.Header.Set(stackroxNamespaceHeader, "test-namespace") - // First request + // First request. _, err := transport.RoundTrip(req) require.NoError(t, err) assert.Equal(t, 1, callCount) - // Second request - should use cached token + // Second request - should use cached token. _, err = transport.RoundTrip(req) require.NoError(t, err) assert.Equal(t, 1, callCount, "token should still be cached for status %d", statusCode) @@ -611,12 +753,12 @@ func TestScopedTokenTransport_InvalidateOnUnauthorized(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/v1/alerts", nil) req.Header.Set(stackroxNamespaceHeader, "test-namespace") - // First request - transport error + // First request - transport error. 
_, err := transport.RoundTrip(req) require.Error(t, err) assert.Equal(t, 1, callCount) - // Second request - should use cached token (error didn't invalidate) + // Second request - should use cached token (error didn't invalidate). _, err = transport.RoundTrip(req) require.Error(t, err) assert.Equal(t, 1, callCount, "token should still be cached after transport error") @@ -634,24 +776,25 @@ func TestTokenProvider_InvalidateToken(t *testing.T) { } provider := newTestTokenProvider(fakeClient, "test-cluster-id") + key := scopeKey("my-scope") - // Get token (causes cache) - token1, err := provider.getTokenForScope(context.Background(), "my-scope") + // Get token (causes cache). + token1, err := provider.getTokenForScope(context.Background(), key) require.NoError(t, err) assert.Equal(t, "token-1", token1) assert.Equal(t, 1, callCount) - // Get again - should be cached - token2, err := provider.getTokenForScope(context.Background(), "my-scope") + // Get again - should be cached. + token2, err := provider.getTokenForScope(context.Background(), key) require.NoError(t, err) assert.Equal(t, "token-1", token2) assert.Equal(t, 1, callCount) - // Invalidate - provider.invalidateToken("my-scope") + // Invalidate. + provider.invalidateToken(key) - // Get again - should fetch new token - token3, err := provider.getTokenForScope(context.Background(), "my-scope") + // Get again - should fetch new token. + token3, err := provider.getTokenForScope(context.Background(), key) require.NoError(t, err) assert.Equal(t, "token-2", token3) assert.Equal(t, 2, callCount) @@ -667,29 +810,78 @@ func TestTokenProvider_InvalidateToken(t *testing.T) { } provider := newTestTokenProvider(fakeClient, "test-cluster-id") + keyA := scopeKey("scope-a") + keyB := scopeKey("scope-b") - // Cache tokens for two scopes - _, err := provider.getTokenForScope(context.Background(), "scope-a") + // Cache tokens for two scopes. 
+ _, err := provider.getTokenForScope(context.Background(), keyA) require.NoError(t, err) - _, err = provider.getTokenForScope(context.Background(), "scope-b") + _, err = provider.getTokenForScope(context.Background(), keyB) require.NoError(t, err) assert.Equal(t, 2, callCount) - // Invalidate only scope-a - provider.invalidateToken("scope-a") + // Invalidate only scope-a. + provider.invalidateToken(keyA) - // scope-a should get new token - tokenA, err := provider.getTokenForScope(context.Background(), "scope-a") + // scope-a should get new token. + tokenA, err := provider.getTokenForScope(context.Background(), keyA) require.NoError(t, err) assert.Equal(t, "token-3", tokenA) assert.Equal(t, 3, callCount) - // scope-b should still be cached - tokenB, err := provider.getTokenForScope(context.Background(), "scope-b") + // scope-b should still be cached. + tokenB, err := provider.getTokenForScope(context.Background(), keyB) require.NoError(t, err) assert.Equal(t, "token-2", tokenB) assert.Equal(t, 3, callCount) }) + + t.Run("invalidateToken only affects specified user within same scope", func(t *testing.T) { + tokenIndex := 0 + tokens := []string{ + "token-user-a-1", // Initial token for user-a. + "token-user-b-1", // Initial token for user-b. + "token-user-a-2", // Refreshed token for user-a after invalidation. + } + + fakeClient := &dynamicFakeTokenServiceClient{ + getToken: func() string { + token := tokens[tokenIndex] + tokenIndex++ + return token + }, + } + + provider := newTestTokenProvider(fakeClient, "test-cluster-id") + + keyA := tokenCacheKey{scope: "shared-scope", username: "user-a"} + keyB := tokenCacheKey{scope: "shared-scope", username: "user-b"} + + // Cache tokens for both users. 
+ tokenA1, err := provider.getTokenForScope(context.Background(), keyA) + require.NoError(t, err) + assert.Equal(t, "token-user-a-1", tokenA1) + + tokenB1, err := provider.getTokenForScope(context.Background(), keyB) + require.NoError(t, err) + assert.Equal(t, "token-user-b-1", tokenB1) + + // Invalidate only user-a's token. + provider.invalidateToken(keyA) + + // User-a should trigger a new RPC and get a new token. + tokenA2, err := provider.getTokenForScope(context.Background(), keyA) + require.NoError(t, err) + assert.Equal(t, "token-user-a-2", tokenA2) + + // User-b should still use the cached token without triggering a new RPC. + tokenB2, err := provider.getTokenForScope(context.Background(), keyB) + require.NoError(t, err) + assert.Equal(t, "token-user-b-1", tokenB2) + + // Exactly three RPCs: initial user-a, initial user-b, refreshed user-a. + assert.Equal(t, 3, tokenIndex) + }) } func TestScopedTokenTransport_RetryWithRequestBody(t *testing.T) { @@ -706,7 +898,7 @@ func TestScopedTokenTransport_RetryWithRequestBody(t *testing.T) { requestCount := 0 mockBase := roundTripperFunc(func(req *http.Request) (*http.Response, error) { requestCount++ - // Read the body to verify it's available + // Read the body to verify it's available. body, _ := io.ReadAll(req.Body) bodiesReceived = append(bodiesReceived, string(body)) @@ -739,7 +931,7 @@ func TestScopedTokenTransport_RetryWithRequestBody(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode, "should return success after retry") assert.Equal(t, 2, tokenCallCount, "should have requested two tokens") assert.Equal(t, 2, requestCount, "should have made two requests") - // Both requests should have received the body + // Both requests should have received the body. 
assert.Equal(t, []string{bodyContent, bodyContent}, bodiesReceived, "both requests should receive the body") }) @@ -821,7 +1013,7 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { // This is used for deterministic testing of request coalescing without timing dependencies. type barrierFakeTokenServiceClient struct { getToken func() string - barrier <-chan struct{} // If non-nil, blocks until closed before returning + barrier <-chan struct{} // If non-nil, blocks until closed before returning. } func (b *barrierFakeTokenServiceClient) GenerateTokenForPermissionsAndScope( @@ -860,30 +1052,32 @@ func TestTokenProvider_Coalescing(t *testing.T) { tokens := make([]string, numGoroutines) errs := make([]error, numGoroutines) - // Use a separate WaitGroup to track when all goroutines have started + // Use a separate WaitGroup to track when all goroutines have started. var startWg sync.WaitGroup startWg.Add(numGoroutines) - // Launch concurrent requests for the same scope + key := scopeKey("shared-scope") + + // Launch concurrent requests for the same scope. for i := 0; i < numGoroutines; i++ { wg.Add(1) go func(idx int) { defer wg.Done() - startWg.Done() // Signal this goroutine has started - tokens[idx], errs[idx] = provider.getTokenForScope(context.Background(), "shared-scope") + startWg.Done() // Signal this goroutine has started. + tokens[idx], errs[idx] = provider.getTokenForScope(context.Background(), key) }(i) } - // Wait for all goroutines to start, then release the barrier + // Wait for all goroutines to start, then release the barrier. startWg.Wait() close(barrier) wg.Wait() - // Verify only ONE RPC call was made + // Verify only ONE RPC call was made. assert.Equal(t, int32(1), callCount.Load(), "expected exactly 1 RPC call for %d concurrent requests", numGoroutines) - // Verify all goroutines got the same token + // Verify all goroutines got the same token. 
for i := 0; i < numGoroutines; i++ { require.NoError(t, errs[i]) assert.Equal(t, "coalesced-token", tokens[i]) @@ -897,18 +1091,19 @@ func TestTokenProvider_Coalescing(t *testing.T) { } provider := newTestTokenProvider(fakeClient, "test-cluster-id") + key := scopeKey("error-scope") - // First request fails + // First request fails. fakeClient.err = errors.New("transient failure") - _, err := provider.getTokenForScope(context.Background(), "error-scope") + _, err := provider.getTokenForScope(context.Background(), key) require.Error(t, err) // singleflight removes the key after the call completes with error - // so a second request will trigger a new RPC call + // so a second request will trigger a new RPC call. fakeClient.err = nil fakeClient.response = ¢ralv1.GenerateTokenForPermissionsAndScopeResponse{Token: "recovered-token"} - token, err := provider.getTokenForScope(context.Background(), "error-scope") + token, err := provider.getTokenForScope(context.Background(), key) require.NoError(t, err) assert.Equal(t, "recovered-token", token) assert.Equal(t, int32(2), callCount.Load(), "expected two RPC calls (one failure, one success)") @@ -925,18 +1120,19 @@ func TestTokenProvider_Coalescing(t *testing.T) { } provider := newTestTokenProvider(fakeClient, "test-cluster-id") + key := scopeKey("invalidate-scope") - // First request - token1, err := provider.getTokenForScope(context.Background(), "invalidate-scope") + // First request. + token1, err := provider.getTokenForScope(context.Background(), key) require.NoError(t, err) assert.Equal(t, "token-1", token1) assert.Equal(t, int32(1), callCount.Load()) - // Invalidate - provider.invalidateToken("invalidate-scope") + // Invalidate. + provider.invalidateToken(key) - // Second request should trigger new RPC - token2, err := provider.getTokenForScope(context.Background(), "invalidate-scope") + // Second request should trigger new RPC. 
+ token2, err := provider.getTokenForScope(context.Background(), key) require.NoError(t, err) assert.Equal(t, "token-2", token2) assert.Equal(t, int32(2), callCount.Load())