diff --git a/pkg/sync/common_aliases.go b/pkg/sync/common_aliases.go index 41ef90e6f8d98..44305380a8902 100644 --- a/pkg/sync/common_aliases.go +++ b/pkg/sync/common_aliases.go @@ -10,3 +10,6 @@ type WaitGroup = sync.WaitGroup // Locker is an alias for `sync.Locker`. type Locker = sync.Locker + +// Map is an alias for `sync.Map`. +type Map = sync.Map diff --git a/sensor/kubernetes/localscanner/certificate_requester.go b/sensor/kubernetes/localscanner/certificate_requester.go new file mode 100644 index 0000000000000..4de43ad69a69e --- /dev/null +++ b/sensor/kubernetes/localscanner/certificate_requester.go @@ -0,0 +1,125 @@ +package localscanner + +import ( + "context" + + "github.com/pkg/errors" + "github.com/stackrox/rox/generated/internalapi/central" + "github.com/stackrox/rox/pkg/concurrency" + "github.com/stackrox/rox/pkg/logging" + "github.com/stackrox/rox/pkg/sync" + "github.com/stackrox/rox/pkg/uuid" +) + +var ( + // ErrCertificateRequesterStopped is returned by RequestCertificates when the certificate + // requester is not initialized. + ErrCertificateRequesterStopped = errors.New("stopped") + log = logging.LoggerForModule() + _ CertificateRequester = (*certificateRequesterImpl)(nil) +) + +// CertificateRequester requests a new set of local scanner certificates from central. +type CertificateRequester interface { + Start() + Stop() + RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error) +} + +// NewCertificateRequester creates a new certificate requester that communicates through +// the specified channels and initializes a new request ID for reach request. +// To use it call Start, and then make requests with RequestCertificates, concurrent requests are supported. +// This assumes that the returned certificate requester is the only consumer of `receiveC`. +func NewCertificateRequester(sendC chan<- *central.MsgFromSensor, + receiveC <-chan *central.IssueLocalScannerCertsResponse) CertificateRequester { + return &certificateRequesterImpl{ + sendC: sendC, + receiveC: receiveC, + } +} + +type certificateRequesterImpl struct { + sendC chan<- *central.MsgFromSensor + receiveC <-chan *central.IssueLocalScannerCertsResponse + stopC concurrency.ErrorSignal + requests sync.Map +} + +// Start makes the certificate requester listen to `receiveC` and forward responses to any request that is running +// as a call to RequestCertificates. +func (r *certificateRequesterImpl) Start() { + r.stopC.Reset() + go r.dispatchResponses() +} + +// Stop makes the certificate stop forwarding responses to running requests. Subsequent calls to RequestCertificates +// will fail with ErrCertificateRequesterStopped. +// Currently active calls to RequestCertificates will continue running until cancelled or timed out via the +// provided context. +func (r *certificateRequesterImpl) Stop() { + r.stopC.Signal() +} + +func (r *certificateRequesterImpl) dispatchResponses() { + for { + select { + case <-r.stopC.Done(): + return + case msg := <-r.receiveC: + responseC, ok := r.requests.Load(msg.GetRequestId()) + if !ok { + log.Debugf("request ID %q does not match any known request ID, dropping response", + msg.GetRequestId()) + continue + } + r.requests.Delete(msg.GetRequestId()) + // Doesn't block even if the corresponding call to RequestCertificates is cancelled and no one + // ever reads this, because requestC has buffer of 1, and we removed it from `r.request` above, + // in case we get more than 1 response for `msg.GetRequestId()`. + responseC.(chan *central.IssueLocalScannerCertsResponse) <- msg + } + } +} + +// RequestCertificates makes a new request for a new set of local scanner certificates from central. +// This assumes the certificate requester is started, otherwise this returns ErrCertificateRequesterStopped. +func (r *certificateRequesterImpl) RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error) { + requestID := uuid.NewV4().String() + receiveC := make(chan *central.IssueLocalScannerCertsResponse, 1) + r.requests.Store(requestID, receiveC) + // Always delete this entry when leaving this scope to account for requests that are never responded, to avoid + // having entries in `r.requests` that are never removed. + defer r.requests.Delete(requestID) + + if err := r.send(ctx, requestID); err != nil { + return nil, err + } + return receive(ctx, receiveC) +} + +func (r *certificateRequesterImpl) send(ctx context.Context, requestID string) error { + msg := ¢ral.MsgFromSensor{ + Msg: ¢ral.MsgFromSensor_IssueLocalScannerCertsRequest{ + IssueLocalScannerCertsRequest: ¢ral.IssueLocalScannerCertsRequest{ + RequestId: requestID, + }, + }, + } + select { + case <-r.stopC.Done(): + return r.stopC.ErrorWithDefault(ErrCertificateRequesterStopped) + case <-ctx.Done(): + return ctx.Err() + case r.sendC <- msg: + return nil + } +} + +func receive(ctx context.Context, receiveC <-chan *central.IssueLocalScannerCertsResponse) (*central.IssueLocalScannerCertsResponse, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case response := <-receiveC: + return response, nil + } +} diff --git a/sensor/kubernetes/localscanner/certificate_requester_test.go b/sensor/kubernetes/localscanner/certificate_requester_test.go new file mode 100644 index 0000000000000..d69aae3374c38 --- /dev/null +++ b/sensor/kubernetes/localscanner/certificate_requester_test.go @@ -0,0 +1,183 @@ +package localscanner + +import ( + "context" + "math/rand" + "sync/atomic" + "testing" + "time" + + "github.com/stackrox/rox/generated/internalapi/central" + "github.com/stackrox/rox/pkg/concurrency" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + numConcurrentRequests = 10 +) + +var ( + testTimeout = time.Second +) + +func TestCertificateRequesterRequestFailureIfStopped(t *testing.T) { + testCases := map[string]struct { + startRequester bool + }{ + "requester not started": {false}, + "requester stopped before request": {true}, + } + for tcName, tc := range testCases { + t.Run(tcName, func(t *testing.T) { + f := newFixture(0) + defer f.tearDown() + if tc.startRequester { + f.requester.Start() + f.requester.Stop() + } + + certs, requestErr := f.requester.RequestCertificates(f.ctx) + assert.Nil(t, certs) + assert.Equal(t, ErrCertificateRequesterStopped, requestErr) + }) + } +} + +func TestCertificateRequesterRequestCancellation(t *testing.T) { + f := newFixture(0) + f.requester.Start() + defer f.tearDown() + + f.cancelCtx() + certs, requestErr := f.requester.RequestCertificates(f.ctx) + assert.Nil(t, certs) + assert.Equal(t, context.Canceled, requestErr) +} + +func TestCertificateRequesterRequestSuccess(t *testing.T) { + f := newFixture(0) + f.requester.Start() + defer f.tearDown() + + go f.respondRequest(t, 0, nil) + + response, err := f.requester.RequestCertificates(f.ctx) + assert.NoError(t, err) + assert.Equal(t, f.interceptedRequestID.Load(), response.GetRequestId()) +} + +func TestCertificateRequesterResponsesWithUnknownIDAreIgnored(t *testing.T) { + f := newFixture(100 * time.Millisecond) + f.requester.Start() + defer f.tearDown() + + // Request with different request ID should be ignored. + go f.respondRequest(t, 0, ¢ral.IssueLocalScannerCertsResponse{RequestId: "UNKNOWN"}) + + certs, requestErr := f.requester.RequestCertificates(f.ctx) + assert.Nil(t, certs) + assert.Equal(t, context.DeadlineExceeded, requestErr) +} + +func TestCertificateRequesterRequestConcurrentRequestDoNotInterfere(t *testing.T) { + testCases := map[string]struct { + responseDelayFunc func(requestIndex int) (responseDelay time.Duration) + }{ + "decreasing response delay": {func(requestIndex int) (responseDelay time.Duration) { + // responses are responded increasingly faster, so always out of order. + return time.Duration(numConcurrentRequests-(requestIndex+1)) * 10 * time.Millisecond + }}, + "random response delay": {func(requestIndex int) (responseDelay time.Duration) { + // randomly out of order responses. + return time.Duration(rand.Intn(100)) * time.Millisecond + }}, + } + for tcName, tc := range testCases { + t.Run(tcName, func(t *testing.T) { + f := newFixture(0) + f.requester.Start() + defer f.tearDown() + waitGroup := concurrency.NewWaitGroup(numConcurrentRequests) + + for i := 0; i < numConcurrentRequests; i++ { + i := i + responseDelay := tc.responseDelayFunc(i) + go f.respondRequest(t, responseDelay, nil) + go func() { + defer waitGroup.Add(-1) + _, err := f.requester.RequestCertificates(f.ctx) + assert.NoError(t, err) + }() + } + ok := concurrency.WaitWithTimeout(&waitGroup, time.Duration(numConcurrentRequests)*testTimeout) + require.True(t, ok) + }) + } +} + +type certificateRequesterFixture struct { + sendC chan *central.MsgFromSensor + receiveC chan *central.IssueLocalScannerCertsResponse + requester CertificateRequester + interceptedRequestID *atomic.Value + ctx context.Context + cancelCtx context.CancelFunc +} + +// newFixture creates a new test fixture that uses `timeout` as context timeout if `timeout` is +// not 0, and `testTimeout` otherwise. +func newFixture(timeout time.Duration) *certificateRequesterFixture { + sendC := make(chan *central.MsgFromSensor) + receiveC := make(chan *central.IssueLocalScannerCertsResponse) + requester := NewCertificateRequester(sendC, receiveC) + var interceptedRequestID atomic.Value + if timeout == 0 { + timeout = testTimeout + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + return &certificateRequesterFixture{ + sendC: sendC, + receiveC: receiveC, + requester: requester, + ctx: ctx, + cancelCtx: cancel, + interceptedRequestID: &interceptedRequestID, + } +} + +func (f *certificateRequesterFixture) tearDown() { + f.cancelCtx() + f.requester.Stop() +} + +// respondRequest reads a request from `f.sendC` and responds with `responseOverwrite` if not nil, or with +// a response with the same ID as the request otherwise. If `responseDelay` is greater than 0 then this function +// waits for that time before sending the response. +// Before sending the response, it stores in `f.interceptedRequestID` the request ID for the requests read from `f.sendC`. +func (f *certificateRequesterFixture) respondRequest(t *testing.T, responseDelay time.Duration, responseOverwrite *central.IssueLocalScannerCertsResponse) { + select { + case <-f.ctx.Done(): + case request := <-f.sendC: + interceptedRequestID := request.GetIssueLocalScannerCertsRequest().GetRequestId() + assert.NotEmpty(t, interceptedRequestID) + var response *central.IssueLocalScannerCertsResponse + if responseOverwrite != nil { + response = responseOverwrite + } else { + response = ¢ral.IssueLocalScannerCertsResponse{RequestId: interceptedRequestID} + } + f.interceptedRequestID.Store(response.GetRequestId()) + if responseDelay > 0 { + select { + case <-f.ctx.Done(): + return + case <-time.After(responseDelay): + } + } + select { + case <-f.ctx.Done(): + case f.receiveC <- response: + } + } +}