Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/sync/common_aliases.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ type WaitGroup = sync.WaitGroup

// Locker is an alias for `sync.Locker`.
type Locker = sync.Locker

// Map is an alias for `sync.Map`.
type Map = sync.Map
55 changes: 55 additions & 0 deletions sensor/kubernetes/localscanner/certificate_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package localscanner

import (
"context"

"github.com/stackrox/rox/generated/internalapi/central"
)

var (
_ certificateRequest = (*certRequestSyncImpl)(nil)
)

// certificateRequest request a new set of local scanner certificates to central.
type certificateRequest interface {
requestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error)
}

type certRequestSyncImpl struct {
requestID string
msgFromSensorC msgFromSensorC
msgToSensorC msgToSensorC
}

func (i *certRequestSyncImpl) requestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error) {
msg := &central.MsgFromSensor{
Msg: &central.MsgFromSensor_IssueLocalScannerCertsRequest{
IssueLocalScannerCertsRequest: &central.IssueLocalScannerCertsRequest{
RequestId: i.requestID,
},
},
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case i.msgFromSensorC <- msg:
log.Debugf("request to issue local Scanner certificates sent to Central successfully: %v", msg)
}

var response *central.IssueLocalScannerCertsResponse
for response == nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
case newResponse := <-i.msgToSensorC:
if newResponse.GetRequestId() != i.requestID {
log.Debugf("request id %q does not match %q, skipping request", response.GetRequestId(),
i.requestID)
} else {
response = newResponse
}
}
}

return response, nil
}
81 changes: 81 additions & 0 deletions sensor/kubernetes/localscanner/certificate_requester.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package localscanner

import (
"context"

"github.com/stackrox/rox/generated/internalapi/central"
"github.com/stackrox/rox/pkg/concurrency"
"github.com/stackrox/rox/pkg/logging"
"github.com/stackrox/rox/pkg/sync"
"github.com/stackrox/rox/pkg/uuid"
)

var (
log = logging.LoggerForModule()
_ CertificateRequester = (*certificateRequesterImpl)(nil)
)

// CertificateRequester request a new set of local scanner certificates to central.
type CertificateRequester interface {
Start()
Stop()
RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error)
}

// NewCertificateRequester creates a new certificate requester that communicates through
// the specified channels and initializes a new request ID for reach request.
// To use it call Start, and then make requests with RequestCertificates, concurrent requests are supported.
// This assumes that the certificate requester is the only consumer of msgToSensorC.
func NewCertificateRequester(msgFromSensorC msgFromSensorC, msgToSensorC msgToSensorC) CertificateRequester {
return &certificateRequesterImpl{
stopC: concurrency.NewErrorSignal(),
msgFromSensorC: msgFromSensorC,
msgToSensorC: msgToSensorC,
}
}

type msgFromSensorC chan *central.MsgFromSensor
type msgToSensorC chan *central.IssueLocalScannerCertsResponse
type certificateRequesterImpl struct {
stopC concurrency.ErrorSignal
msgFromSensorC msgFromSensorC
msgToSensorC msgToSensorC
requests sync.Map
}

func (m *certificateRequesterImpl) Start() {
go m.forwardMessagesToSensor()
}

func (m *certificateRequesterImpl) Stop() {
m.stopC.Signal()
}

func (m *certificateRequesterImpl) forwardMessagesToSensor() {
for {
select {
case <-m.stopC.Done():
return
case msg := <-m.msgToSensorC:
requestC, ok := m.requests.Load(msg.GetRequestId())
if ok {
requestC.(msgToSensorC) <- msg
} else {
log.Debugf("request ID %q does not match any known request ID, skipping request",
msg.GetRequestId()) // FIXME debug
}
}
}
}

func (m *certificateRequesterImpl) RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error) {
request := &certRequestSyncImpl{
requestID: uuid.NewV4().String(),
msgFromSensorC: m.msgFromSensorC,
msgToSensorC: make(msgToSensorC),
}
m.requests.Store(request.requestID, request.msgToSensorC)
response, err := request.requestCertificates(ctx)
m.requests.Delete(request.requestID)
return response, err
}
141 changes: 141 additions & 0 deletions sensor/kubernetes/localscanner/certificate_requester_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package localscanner

import (
"context"
"testing"
"time"

"github.com/stackrox/rox/generated/internalapi/central"
"github.com/stackrox/rox/pkg/concurrency"
"github.com/stretchr/testify/suite"
)

func TestCertificateRequester(t *testing.T) {
suite.Run(t, new(certificateRequesterSuite))
}

type certificateRequesterSuite struct {
suite.Suite
msgFromSensorC msgFromSensorC
msgToSensorC msgToSensorC
requester CertificateRequester
}

func (s *certificateRequesterSuite) SetupTest() {
s.msgFromSensorC = make(msgFromSensorC)
s.msgToSensorC = make(msgToSensorC)
s.requester = NewCertificateRequester(s.msgFromSensorC, s.msgToSensorC)
s.requester.Start()
}

func (s *certificateRequesterSuite) TearDownTest() {
s.requester.Stop()
}

func (s *certificateRequesterSuite) TestRequestCancellation() {
requestCtx, cancelRequestCtx := context.WithCancel(context.Background())
doneErrSig := concurrency.NewErrorSignal()

go func() {
certs, err := s.requester.RequestCertificates(requestCtx)
s.Nil(certs)
doneErrSig.SignalWithError(err)
}()
cancelRequestCtx()

requestErr, ok := doneErrSig.WaitWithTimeout(100 * time.Millisecond)
s.Require().True(ok)
s.Equal(context.Canceled, requestErr)
}

func (s *certificateRequesterSuite) TestRequestSuccess() {
waitCtx, cancelWaitCtx := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancelWaitCtx()

responseC := make(msgToSensorC)
var interceptedRequestID string
go func() {
select {
case <-waitCtx.Done():
return
case request := <-s.msgFromSensorC:
interceptedRequestID = request.GetIssueLocalScannerCertsRequest().GetRequestId()
s.NotEmpty(interceptedRequestID)
s.msgToSensorC <- &central.IssueLocalScannerCertsResponse{
RequestId: interceptedRequestID,
}
}
}()

go func() {
response, err := s.requester.RequestCertificates(waitCtx)
s.NoError(err)
responseC <- response
}()

select {
case response := <-responseC:
s.Equal(interceptedRequestID, response.GetRequestId())
case <-waitCtx.Done():
s.Require().Fail("timeout reached")
}
}

func (s *certificateRequesterSuite) TestResponsesWithUnknownIDAreIgnored() {
waitCtx, cancelWaitCtx := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancelWaitCtx()
doneErrSig := concurrency.NewErrorSignal()

go func() {
select {
case <-waitCtx.Done():
case <-s.msgFromSensorC:
select {
case <-waitCtx.Done():
// Request with different request ID should be ignored.
case s.msgToSensorC <- &central.IssueLocalScannerCertsResponse{RequestId: ""}:
}
}
}()

go func() {
certs, err := s.requester.RequestCertificates(waitCtx)
s.Nil(certs)
doneErrSig.SignalWithError(err)
}()

requestErr, ok := doneErrSig.WaitWithTimeout(100 * time.Millisecond)
s.Require().True(ok)
s.Equal(context.DeadlineExceeded, requestErr)
}

func (s *certificateRequesterSuite) TestRequestConcurrentRequestDoNotInterfere() {
waitCtx, cancelWaitCtx := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancelWaitCtx()
numConcurrentRequests := 3
waitGroup := concurrency.NewWaitGroup(numConcurrentRequests)

for i := 0; i < numConcurrentRequests; i++ {
go func() {
select {
case <-waitCtx.Done():
return
case request := <-s.msgFromSensorC:
interceptedRequestID := request.GetIssueLocalScannerCertsRequest().GetRequestId()
s.NotEmpty(interceptedRequestID)
s.msgToSensorC <- &central.IssueLocalScannerCertsResponse{
RequestId: interceptedRequestID,
}
}
}()

go func() {
_, err := s.requester.RequestCertificates(waitCtx)
s.NoError(err)
waitGroup.Add(-1)
}()
}

ok := concurrency.WaitWithTimeout(&waitGroup, 100*time.Millisecond)
s.Require().True(ok)
}