Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ccd298a
Create CertificateRequester to request certificates from central
Jan 26, 2022
10cd58c
fix style
Jan 26, 2022
308ca2f
support concurrent requests
Jan 27, 2022
0b561e4
improve comments and logging
Jan 27, 2022
35baa5b
add tests for wrong request ids in response, and concurrent requests
Jan 27, 2022
d05ff9c
fix style
Jan 27, 2022
4c159de
simplify code removing certRequestSyncImpl
Jan 28, 2022
e5cfae8
do not send request until prepared to receive response
Jan 28, 2022
94149b5
remove type alias and rename channel fields
Jan 28, 2022
83e8c5a
prevent deadlock on request cancellation
Jan 28, 2022
76b916a
simplify tests
Jan 28, 2022
5166fb1
improve code style
Jan 28, 2022
a89e594
restrict channel directions
Jan 28, 2022
b477c0e
improvements in test style
Jan 31, 2022
069beed
improve comments
Jan 31, 2022
c7f5bf1
rename requestC to responseC
Jan 31, 2022
07f48cf
improve comments
Jan 31, 2022
a7c024d
check the requester was started on RequestCertificates and fail owise
Jan 31, 2022
cb2527e
wip fixing flaky test
Jan 31, 2022
5fa4ad9
fix bug in Start
Feb 1, 2022
6987961
replace suite by independent suite objects
Feb 1, 2022
a7f4600
improve comments and logs
Feb 1, 2022
43ddb58
add test with requester never started
Feb 1, 2022
03b0153
simplify running failure tests assuming test execution CI task alread…
Feb 1, 2022
b84558b
improve concurrent request test
Feb 1, 2022
3e7eac1
add a deterministic response shuffling strategy
Feb 1, 2022
55d164a
change test case struct for concurrent request test
Feb 2, 2022
35e39a3
improve comment
Feb 2, 2022
8a059f3
make request delay wait cancellable
Feb 2, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/sync/common_aliases.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ type WaitGroup = sync.WaitGroup

// Locker is an alias for `sync.Locker`.
type Locker = sync.Locker

// Map is an alias for `sync.Map`.
type Map = sync.Map
125 changes: 125 additions & 0 deletions sensor/kubernetes/localscanner/certificate_requester.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package localscanner

import (
"context"

"github.com/pkg/errors"
"github.com/stackrox/rox/generated/internalapi/central"
"github.com/stackrox/rox/pkg/concurrency"
"github.com/stackrox/rox/pkg/logging"
"github.com/stackrox/rox/pkg/sync"
"github.com/stackrox/rox/pkg/uuid"
)

var (
// ErrCertificateRequesterStopped is returned by RequestCertificates when the certificate
// requester is not initialized.
ErrCertificateRequesterStopped = errors.New("stopped")
log = logging.LoggerForModule()
_ CertificateRequester = (*certificateRequesterImpl)(nil)
)

// CertificateRequester requests a new set of local scanner certificates from central.
type CertificateRequester interface {
Start()
Stop()
RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error)
}

// NewCertificateRequester creates a new certificate requester that communicates through
// the specified channels and initializes a new request ID for reach request.
// To use it call Start, and then make requests with RequestCertificates, concurrent requests are supported.
// This assumes that the returned certificate requester is the only consumer of `receiveC`.
func NewCertificateRequester(sendC chan<- *central.MsgFromSensor,
receiveC <-chan *central.IssueLocalScannerCertsResponse) CertificateRequester {
return &certificateRequesterImpl{
sendC: sendC,
receiveC: receiveC,
}
}

type certificateRequesterImpl struct {
sendC chan<- *central.MsgFromSensor
receiveC <-chan *central.IssueLocalScannerCertsResponse
stopC concurrency.ErrorSignal
requests sync.Map
}

// Start makes the certificate requester listen to `receiveC` and forward responses to any request that is running
// as a call to RequestCertificates.
func (r *certificateRequesterImpl) Start() {
r.stopC.Reset()
go r.dispatchResponses()
}

// Stop makes the certificate stop forwarding responses to running requests. Subsequent calls to RequestCertificates
// will fail with ErrCertificateRequesterStopped.
// Currently active calls to RequestCertificates will continue running until cancelled or timed out via the
// provided context.
func (r *certificateRequesterImpl) Stop() {
r.stopC.Signal()
}

func (r *certificateRequesterImpl) dispatchResponses() {
for {
select {
case <-r.stopC.Done():
return
case msg := <-r.receiveC:
responseC, ok := r.requests.Load(msg.GetRequestId())
if !ok {
log.Debugf("request ID %q does not match any known request ID, dropping response",
msg.GetRequestId())
continue
}
r.requests.Delete(msg.GetRequestId())
// Doesn't block even if the corresponding call to RequestCertificates is cancelled and no one
// ever reads this, because requestC has buffer of 1, and we removed it from `r.request` above,
// in case we get more than 1 response for `msg.GetRequestId()`.
responseC.(chan *central.IssueLocalScannerCertsResponse) <- msg
}
}
}

// RequestCertificates makes a new request for a new set of local scanner certificates from central.
// This assumes the certificate requester is started, otherwise this returns ErrCertificateRequesterStopped.
func (r *certificateRequesterImpl) RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this perhaps assert that Start() has already been called?
Also I think an exported method requires a docstring.

Speaking of docstrings, it would be great to explain why we have this worker goroutine, rather than do the central response handling directly in RequestCertificates. I'm assuming this is because we want to support overlapping requests/response timelines?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's weird that the style checks have passed without such docstrings, I'll add that to Start, Stop and RequestCertificates, including that precondition in RequestCertificates.

We need to launch dispatchResponses to support concurrent requests, as we have a single channel to get the responses, which is the interface we have in SensorComponent. We have a unit test TestRequestConcurrentRequestDoNotInterfere for that.
For now we don't plan to have concurrent request for local scanner, because the retry ticker only runs one instance of the tick function at a time, but during this PR we came to the conclusion that supporting concurrent requests would make this type easier to use

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, rather than just checking if Start() has been called via a simple if condition, I would propose to zero-initialize the stopC error signal. Then, you call r.stopC.Reset() at the beginning of Start(), and select on <-r.stopC.Done() in the send method, returning r.stopC.ErrorOrDefault(errors.New("not started")) if that branch is taken.

Re docstrings: Docstrings are actually only reported for functions that are explicitly exported. I.e., func Foo() requires a docstring, as does func (t *ExportedType) Foo(), but func (t *unexportedType) Foo() does not, because this concrete function (*unexportedType).Foo is not exported (OTOH, you can reference ExportedType.Foo even in the absence of a receiver object - it will be of type func(*ExportedType)). The function may be exported as an interface method, but the way Golang interfaces work, it generally isn't possible to rule that such an interface exists, as it might even be declared in a different package. Also, the generated go docs would never show this function anywhere (arguably, the linter should require a docstring on each method of an exported interface, but it doesn't do that).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the docstrings explanation @misberner , TIL.

requestID := uuid.NewV4().String()
receiveC := make(chan *central.IssueLocalScannerCertsResponse, 1)
r.requests.Store(requestID, receiveC)
// Always delete this entry when leaving this scope to account for requests that are never responded, to avoid
// having entries in `r.requests` that are never removed.
defer r.requests.Delete(requestID)

if err := r.send(ctx, requestID); err != nil {
return nil, err
}
return receive(ctx, receiveC)
}

func (r *certificateRequesterImpl) send(ctx context.Context, requestID string) error {
msg := &central.MsgFromSensor{
Msg: &central.MsgFromSensor_IssueLocalScannerCertsRequest{
IssueLocalScannerCertsRequest: &central.IssueLocalScannerCertsRequest{
RequestId: requestID,
},
},
}
select {
case <-r.stopC.Done():
return r.stopC.ErrorWithDefault(ErrCertificateRequesterStopped)
case <-ctx.Done():
return ctx.Err()
case r.sendC <- msg:
return nil
}
}

func receive(ctx context.Context, receiveC <-chan *central.IssueLocalScannerCertsResponse) (*central.IssueLocalScannerCertsResponse, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it intentional to not be a receiver func of certificateRequesterImpl?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, to stress we are not using any of the state, receive just uses information that is local to a particular request (the context and the receive channel)

select {
case <-ctx.Done():
return nil, ctx.Err()
case response := <-receiveC:
return response, nil
}
}
183 changes: 183 additions & 0 deletions sensor/kubernetes/localscanner/certificate_requester_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package localscanner

import (
"context"
"math/rand"
"sync/atomic"
"testing"
"time"

"github.com/stackrox/rox/generated/internalapi/central"
"github.com/stackrox/rox/pkg/concurrency"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const (
numConcurrentRequests = 10
)

var (
testTimeout = time.Second
)

func TestCertificateRequesterRequestFailureIfStopped(t *testing.T) {
testCases := map[string]struct {
startRequester bool
}{
"requester not started": {false},
"requester stopped before request": {true},
}
for tcName, tc := range testCases {
t.Run(tcName, func(t *testing.T) {
f := newFixture(0)
defer f.tearDown()
if tc.startRequester {
f.requester.Start()
f.requester.Stop()
}

certs, requestErr := f.requester.RequestCertificates(f.ctx)
assert.Nil(t, certs)
assert.Equal(t, ErrCertificateRequesterStopped, requestErr)
})
}
}

func TestCertificateRequesterRequestCancellation(t *testing.T) {
f := newFixture(0)
f.requester.Start()
defer f.tearDown()

f.cancelCtx()
certs, requestErr := f.requester.RequestCertificates(f.ctx)
assert.Nil(t, certs)
assert.Equal(t, context.Canceled, requestErr)
}

func TestCertificateRequesterRequestSuccess(t *testing.T) {
f := newFixture(0)
f.requester.Start()
defer f.tearDown()

go f.respondRequest(t, 0, nil)

response, err := f.requester.RequestCertificates(f.ctx)
assert.NoError(t, err)
assert.Equal(t, f.interceptedRequestID.Load(), response.GetRequestId())
}

func TestCertificateRequesterResponsesWithUnknownIDAreIgnored(t *testing.T) {
f := newFixture(100 * time.Millisecond)
f.requester.Start()
defer f.tearDown()

// Request with different request ID should be ignored.
go f.respondRequest(t, 0, &central.IssueLocalScannerCertsResponse{RequestId: "UNKNOWN"})

certs, requestErr := f.requester.RequestCertificates(f.ctx)
assert.Nil(t, certs)
assert.Equal(t, context.DeadlineExceeded, requestErr)
}

func TestCertificateRequesterRequestConcurrentRequestDoNotInterfere(t *testing.T) {
testCases := map[string]struct {
responseDelayFunc func(requestIndex int) (responseDelay time.Duration)
}{
"decreasing response delay": {func(requestIndex int) (responseDelay time.Duration) {
// responses are responded increasingly faster, so always out of order.
return time.Duration(numConcurrentRequests-(requestIndex+1)) * 10 * time.Millisecond
}},
"random response delay": {func(requestIndex int) (responseDelay time.Duration) {
// randomly out of order responses.
return time.Duration(rand.Intn(100)) * time.Millisecond
}},
}
for tcName, tc := range testCases {
t.Run(tcName, func(t *testing.T) {
f := newFixture(0)
f.requester.Start()
defer f.tearDown()
waitGroup := concurrency.NewWaitGroup(numConcurrentRequests)

for i := 0; i < numConcurrentRequests; i++ {
i := i
responseDelay := tc.responseDelayFunc(i)
go f.respondRequest(t, responseDelay, nil)
go func() {
defer waitGroup.Add(-1)
_, err := f.requester.RequestCertificates(f.ctx)
assert.NoError(t, err)
}()
}
ok := concurrency.WaitWithTimeout(&waitGroup, time.Duration(numConcurrentRequests)*testTimeout)
require.True(t, ok)
})
}
}

type certificateRequesterFixture struct {
sendC chan *central.MsgFromSensor
receiveC chan *central.IssueLocalScannerCertsResponse
requester CertificateRequester
interceptedRequestID *atomic.Value
ctx context.Context
cancelCtx context.CancelFunc
}

// newFixture creates a new test fixture that uses `timeout` as context timeout if `timeout` is
// not 0, and `testTimeout` otherwise.
func newFixture(timeout time.Duration) *certificateRequesterFixture {
sendC := make(chan *central.MsgFromSensor)
receiveC := make(chan *central.IssueLocalScannerCertsResponse)
requester := NewCertificateRequester(sendC, receiveC)
var interceptedRequestID atomic.Value
if timeout == 0 {
timeout = testTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
return &certificateRequesterFixture{
sendC: sendC,
receiveC: receiveC,
requester: requester,
ctx: ctx,
cancelCtx: cancel,
interceptedRequestID: &interceptedRequestID,
}
}

func (f *certificateRequesterFixture) tearDown() {
f.cancelCtx()
f.requester.Stop()
}

// respondRequest reads a request from `f.sendC` and responds with `responseOverwrite` if not nil, or with
// a response with the same ID as the request otherwise. If `responseDelay` is greater than 0 then this function
// waits for that time before sending the response.
// Before sending the response, it stores in `f.interceptedRequestID` the request ID for the requests read from `f.sendC`.
func (f *certificateRequesterFixture) respondRequest(t *testing.T, responseDelay time.Duration, responseOverwrite *central.IssueLocalScannerCertsResponse) {
select {
case <-f.ctx.Done():
case request := <-f.sendC:
interceptedRequestID := request.GetIssueLocalScannerCertsRequest().GetRequestId()
assert.NotEmpty(t, interceptedRequestID)
var response *central.IssueLocalScannerCertsResponse
if responseOverwrite != nil {
response = responseOverwrite
} else {
response = &central.IssueLocalScannerCertsResponse{RequestId: interceptedRequestID}
}
f.interceptedRequestID.Store(response.GetRequestId())
if responseDelay > 0 {
select {
case <-f.ctx.Done():
return
case <-time.After(responseDelay):
}
}
select {
case <-f.ctx.Done():
case f.receiveC <- response:
}
}
}