diff --git a/pkg/sync/common_aliases.go b/pkg/sync/common_aliases.go index 41ef90e6f8d98..44305380a8902 100644 --- a/pkg/sync/common_aliases.go +++ b/pkg/sync/common_aliases.go @@ -10,3 +10,6 @@ type WaitGroup = sync.WaitGroup // Locker is an alias for `sync.Locker`. type Locker = sync.Locker + +// Map is an alias for `sync.Map`. +type Map = sync.Map diff --git a/sensor/kubernetes/localscanner/certificate_request.go b/sensor/kubernetes/localscanner/certificate_request.go new file mode 100644 index 0000000000000..445ea73e4beec --- /dev/null +++ b/sensor/kubernetes/localscanner/certificate_request.go @@ -0,0 +1,55 @@ +package localscanner + +import ( + "context" + + "github.com/stackrox/rox/generated/internalapi/central" +) + +var ( + _ certificateRequest = (*certRequestSyncImpl)(nil) +) + +// certificateRequest request a new set of local scanner certificates to central. +type certificateRequest interface { + requestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error) +} + +type certRequestSyncImpl struct { + requestID string + msgFromSensorC msgFromSensorC + msgToSensorC msgToSensorC +} + +func (i *certRequestSyncImpl) requestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error) { + msg := ¢ral.MsgFromSensor{ + Msg: ¢ral.MsgFromSensor_IssueLocalScannerCertsRequest{ + IssueLocalScannerCertsRequest: ¢ral.IssueLocalScannerCertsRequest{ + RequestId: i.requestID, + }, + }, + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case i.msgFromSensorC <- msg: + log.Debugf("request to issue local Scanner certificates sent to Central successfully: %v", msg) + } + + var response *central.IssueLocalScannerCertsResponse + for response == nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case newResponse := <-i.msgToSensorC: + if newResponse.GetRequestId() != i.requestID { + log.Debugf("request id %q does not match %q, skipping request", response.GetRequestId(), + i.requestID) + } else { + response = newResponse + } + } + } + + return response, nil +} diff --git a/sensor/kubernetes/localscanner/certificate_requester.go b/sensor/kubernetes/localscanner/certificate_requester.go new file mode 100644 index 0000000000000..6e786b4dac904 --- /dev/null +++ b/sensor/kubernetes/localscanner/certificate_requester.go @@ -0,0 +1,81 @@ +package localscanner + +import ( + "context" + + "github.com/stackrox/rox/generated/internalapi/central" + "github.com/stackrox/rox/pkg/concurrency" + "github.com/stackrox/rox/pkg/logging" + "github.com/stackrox/rox/pkg/sync" + "github.com/stackrox/rox/pkg/uuid" +) + +var ( + log = logging.LoggerForModule() + _ CertificateRequester = (*certificateRequesterImpl)(nil) +) + +// CertificateRequester request a new set of local scanner certificates to central. +type CertificateRequester interface { + Start() + Stop() + RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error) +} + +// NewCertificateRequester creates a new certificate requester that communicates through +// the specified channels and initializes a new request ID for reach request. +// To use it call Start, and then make requests with RequestCertificates, concurrent requests are supported. +// This assumes that the certificate requester is the only consumer of msgToSensorC. +func NewCertificateRequester(msgFromSensorC msgFromSensorC, msgToSensorC msgToSensorC) CertificateRequester { + return &certificateRequesterImpl{ + stopC: concurrency.NewErrorSignal(), + msgFromSensorC: msgFromSensorC, + msgToSensorC: msgToSensorC, + } +} + +type msgFromSensorC chan *central.MsgFromSensor +type msgToSensorC chan *central.IssueLocalScannerCertsResponse +type certificateRequesterImpl struct { + stopC concurrency.ErrorSignal + msgFromSensorC msgFromSensorC + msgToSensorC msgToSensorC + requests sync.Map +} + +func (m *certificateRequesterImpl) Start() { + go m.forwardMessagesToSensor() +} + +func (m *certificateRequesterImpl) Stop() { + m.stopC.Signal() +} + +func (m *certificateRequesterImpl) forwardMessagesToSensor() { + for { + select { + case <-m.stopC.Done(): + return + case msg := <-m.msgToSensorC: + requestC, ok := m.requests.Load(msg.GetRequestId()) + if ok { + requestC.(msgToSensorC) <- msg + } else { + log.Debugf("request ID %q does not match any known request ID, skipping request", + msg.GetRequestId()) // FIXME debug + } + } + } +} + +func (m *certificateRequesterImpl) RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error) { + request := &certRequestSyncImpl{ + requestID: uuid.NewV4().String(), + msgFromSensorC: m.msgFromSensorC, + msgToSensorC: make(msgToSensorC), + } + m.requests.Store(request.requestID, request.msgToSensorC) + response, err := request.requestCertificates(ctx) + m.requests.Delete(request.requestID) + return response, err +} diff --git a/sensor/kubernetes/localscanner/certificate_requester_test.go b/sensor/kubernetes/localscanner/certificate_requester_test.go new file mode 100644 index 0000000000000..42ed590e4d525 --- /dev/null +++ b/sensor/kubernetes/localscanner/certificate_requester_test.go @@ -0,0 +1,141 @@ +package localscanner + +import ( + "context" + "testing" + "time" + + "github.com/stackrox/rox/generated/internalapi/central" + "github.com/stackrox/rox/pkg/concurrency" + "github.com/stretchr/testify/suite" +) + +func TestCertificateRequester(t *testing.T) { + suite.Run(t, new(certificateRequesterSuite)) +} + +type certificateRequesterSuite struct { + suite.Suite + msgFromSensorC msgFromSensorC + msgToSensorC msgToSensorC + requester CertificateRequester +} + +func (s *certificateRequesterSuite) SetupTest() { + s.msgFromSensorC = make(msgFromSensorC) + s.msgToSensorC = make(msgToSensorC) + s.requester = NewCertificateRequester(s.msgFromSensorC, s.msgToSensorC) + s.requester.Start() +} + +func (s *certificateRequesterSuite) TearDownTest() { + s.requester.Stop() +} + +func (s *certificateRequesterSuite) TestRequestCancellation() { + requestCtx, cancelRequestCtx := context.WithCancel(context.Background()) + doneErrSig := concurrency.NewErrorSignal() + + go func() { + certs, err := s.requester.RequestCertificates(requestCtx) + s.Nil(certs) + doneErrSig.SignalWithError(err) + }() + cancelRequestCtx() + + requestErr, ok := doneErrSig.WaitWithTimeout(100 * time.Millisecond) + s.Require().True(ok) + s.Equal(context.Canceled, requestErr) +} + +func (s *certificateRequesterSuite) TestRequestSuccess() { + waitCtx, cancelWaitCtx := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancelWaitCtx() + + responseC := make(msgToSensorC) + var interceptedRequestID string + go func() { + select { + case <-waitCtx.Done(): + return + case request := <-s.msgFromSensorC: + interceptedRequestID = request.GetIssueLocalScannerCertsRequest().GetRequestId() + s.NotEmpty(interceptedRequestID) + s.msgToSensorC <- ¢ral.IssueLocalScannerCertsResponse{ + RequestId: interceptedRequestID, + } + } + }() + + go func() { + response, err := s.requester.RequestCertificates(waitCtx) + s.NoError(err) + responseC <- response + }() + + select { + case response := <-responseC: + s.Equal(interceptedRequestID, response.GetRequestId()) + case <-waitCtx.Done(): + s.Require().Fail("timeout reached") + } +} + +func (s *certificateRequesterSuite) TestResponsesWithUnknownIDAreIgnored() { + waitCtx, cancelWaitCtx := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancelWaitCtx() + doneErrSig := concurrency.NewErrorSignal() + + go func() { + select { + case <-waitCtx.Done(): + case <-s.msgFromSensorC: + select { + case <-waitCtx.Done(): + // Request with different request ID should be ignored. + case s.msgToSensorC <- ¢ral.IssueLocalScannerCertsResponse{RequestId: ""}: + } + } + }() + + go func() { + certs, err := s.requester.RequestCertificates(waitCtx) + s.Nil(certs) + doneErrSig.SignalWithError(err) + }() + + requestErr, ok := doneErrSig.WaitWithTimeout(100 * time.Millisecond) + s.Require().True(ok) + s.Equal(context.DeadlineExceeded, requestErr) +} + +func (s *certificateRequesterSuite) TestRequestConcurrentRequestDoNotInterfere() { + waitCtx, cancelWaitCtx := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancelWaitCtx() + numConcurrentRequests := 3 + waitGroup := concurrency.NewWaitGroup(numConcurrentRequests) + + for i := 0; i < numConcurrentRequests; i++ { + go func() { + select { + case <-waitCtx.Done(): + return + case request := <-s.msgFromSensorC: + interceptedRequestID := request.GetIssueLocalScannerCertsRequest().GetRequestId() + s.NotEmpty(interceptedRequestID) + s.msgToSensorC <- ¢ral.IssueLocalScannerCertsResponse{ + RequestId: interceptedRequestID, + } + } + }() + + go func() { + _, err := s.requester.RequestCertificates(waitCtx) + s.NoError(err) + waitGroup.Add(-1) + }() + } + + ok := concurrency.WaitWithTimeout(&waitGroup, 100*time.Millisecond) + s.Require().True(ok) +}