-
Notifications
You must be signed in to change notification settings - Fork 171
ROX-9127: Create CertificateRequester to request certificates from central #463
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ccd298a
10cd58c
308ca2f
0b561e4
35baa5b
d05ff9c
4c159de
e5cfae8
94149b5
83e8c5a
76b916a
5166fb1
a89e594
b477c0e
069beed
c7f5bf1
07f48cf
a7c024d
cb2527e
5fa4ad9
6987961
a7f4600
43ddb58
03b0153
b84558b
3e7eac1
55d164a
35e39a3
8a059f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| package localscanner | ||
|
|
||
| import ( | ||
| "context" | ||
|
|
||
| "github.com/pkg/errors" | ||
| "github.com/stackrox/rox/generated/internalapi/central" | ||
| "github.com/stackrox/rox/pkg/concurrency" | ||
| "github.com/stackrox/rox/pkg/logging" | ||
| "github.com/stackrox/rox/pkg/sync" | ||
| "github.com/stackrox/rox/pkg/uuid" | ||
| ) | ||
|
|
||
| var ( | ||
| // ErrCertificateRequesterStopped is returned by RequestCertificates when the certificate | ||
| // requester is not initialized. | ||
| ErrCertificateRequesterStopped = errors.New("stopped") | ||
| log = logging.LoggerForModule() | ||
| _ CertificateRequester = (*certificateRequesterImpl)(nil) | ||
| ) | ||
|
|
||
| // CertificateRequester requests a new set of local scanner certificates from central. | ||
| type CertificateRequester interface { | ||
| Start() | ||
| Stop() | ||
| RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error) | ||
| } | ||
|
|
||
| // NewCertificateRequester creates a new certificate requester that communicates through | ||
| // the specified channels and initializes a new request ID for reach request. | ||
| // To use it call Start, and then make requests with RequestCertificates, concurrent requests are supported. | ||
| // This assumes that the returned certificate requester is the only consumer of `receiveC`. | ||
| func NewCertificateRequester(sendC chan<- *central.MsgFromSensor, | ||
| receiveC <-chan *central.IssueLocalScannerCertsResponse) CertificateRequester { | ||
| return &certificateRequesterImpl{ | ||
| sendC: sendC, | ||
| receiveC: receiveC, | ||
| } | ||
| } | ||
|
|
||
| type certificateRequesterImpl struct { | ||
| sendC chan<- *central.MsgFromSensor | ||
| receiveC <-chan *central.IssueLocalScannerCertsResponse | ||
| stopC concurrency.ErrorSignal | ||
| requests sync.Map | ||
| } | ||
|
|
||
| // Start makes the certificate requester listen to `receiveC` and forward responses to any request that is running | ||
| // as a call to RequestCertificates. | ||
| func (r *certificateRequesterImpl) Start() { | ||
| r.stopC.Reset() | ||
| go r.dispatchResponses() | ||
| } | ||
|
|
||
| // Stop makes the certificate stop forwarding responses to running requests. Subsequent calls to RequestCertificates | ||
| // will fail with ErrCertificateRequesterStopped. | ||
| // Currently active calls to RequestCertificates will continue running until cancelled or timed out via the | ||
| // provided context. | ||
| func (r *certificateRequesterImpl) Stop() { | ||
| r.stopC.Signal() | ||
| } | ||
|
|
||
| func (r *certificateRequesterImpl) dispatchResponses() { | ||
| for { | ||
| select { | ||
| case <-r.stopC.Done(): | ||
| return | ||
| case msg := <-r.receiveC: | ||
| responseC, ok := r.requests.Load(msg.GetRequestId()) | ||
| if !ok { | ||
| log.Debugf("request ID %q does not match any known request ID, dropping response", | ||
| msg.GetRequestId()) | ||
| continue | ||
| } | ||
| r.requests.Delete(msg.GetRequestId()) | ||
| // Doesn't block even if the corresponding call to RequestCertificates is cancelled and no one | ||
| // ever reads this, because requestC has buffer of 1, and we removed it from `r.request` above, | ||
| // in case we get more than 1 response for `msg.GetRequestId()`. | ||
| responseC.(chan *central.IssueLocalScannerCertsResponse) <- msg | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // RequestCertificates makes a new request for a new set of local scanner certificates from central. | ||
| // This assumes the certificate requester is started, otherwise this returns ErrCertificateRequesterStopped. | ||
| func (r *certificateRequesterImpl) RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error) { | ||
|
||
| requestID := uuid.NewV4().String() | ||
| receiveC := make(chan *central.IssueLocalScannerCertsResponse, 1) | ||
| r.requests.Store(requestID, receiveC) | ||
| // Always delete this entry when leaving this scope to account for requests that are never responded, to avoid | ||
| // having entries in `r.requests` that are never removed. | ||
| defer r.requests.Delete(requestID) | ||
|
|
||
| if err := r.send(ctx, requestID); err != nil { | ||
| return nil, err | ||
| } | ||
| return receive(ctx, receiveC) | ||
| } | ||
|
|
||
| func (r *certificateRequesterImpl) send(ctx context.Context, requestID string) error { | ||
| msg := ¢ral.MsgFromSensor{ | ||
| Msg: ¢ral.MsgFromSensor_IssueLocalScannerCertsRequest{ | ||
| IssueLocalScannerCertsRequest: ¢ral.IssueLocalScannerCertsRequest{ | ||
| RequestId: requestID, | ||
| }, | ||
| }, | ||
| } | ||
| select { | ||
| case <-r.stopC.Done(): | ||
| return r.stopC.ErrorWithDefault(ErrCertificateRequesterStopped) | ||
| case <-ctx.Done(): | ||
| return ctx.Err() | ||
| case r.sendC <- msg: | ||
| return nil | ||
| } | ||
| } | ||
|
|
||
| func receive(ctx context.Context, receiveC <-chan *central.IssueLocalScannerCertsResponse) (*central.IssueLocalScannerCertsResponse, error) { | ||
|
||
| select { | ||
| case <-ctx.Done(): | ||
| return nil, ctx.Err() | ||
| case response := <-receiveC: | ||
| return response, nil | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,183 @@ | ||
| package localscanner | ||
|
|
||
| import ( | ||
| "context" | ||
| "math/rand" | ||
| "sync/atomic" | ||
| "testing" | ||
| "time" | ||
|
|
||
| "github.com/stackrox/rox/generated/internalapi/central" | ||
| "github.com/stackrox/rox/pkg/concurrency" | ||
| "github.com/stretchr/testify/assert" | ||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| const ( | ||
| numConcurrentRequests = 10 | ||
| ) | ||
|
|
||
| var ( | ||
| testTimeout = time.Second | ||
| ) | ||
|
|
||
| func TestCertificateRequesterRequestFailureIfStopped(t *testing.T) { | ||
| testCases := map[string]struct { | ||
| startRequester bool | ||
| }{ | ||
| "requester not started": {false}, | ||
| "requester stopped before request": {true}, | ||
| } | ||
| for tcName, tc := range testCases { | ||
| t.Run(tcName, func(t *testing.T) { | ||
| f := newFixture(0) | ||
| defer f.tearDown() | ||
| if tc.startRequester { | ||
| f.requester.Start() | ||
| f.requester.Stop() | ||
| } | ||
|
|
||
| certs, requestErr := f.requester.RequestCertificates(f.ctx) | ||
| assert.Nil(t, certs) | ||
| assert.Equal(t, ErrCertificateRequesterStopped, requestErr) | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| func TestCertificateRequesterRequestCancellation(t *testing.T) { | ||
| f := newFixture(0) | ||
| f.requester.Start() | ||
| defer f.tearDown() | ||
|
|
||
| f.cancelCtx() | ||
| certs, requestErr := f.requester.RequestCertificates(f.ctx) | ||
| assert.Nil(t, certs) | ||
| assert.Equal(t, context.Canceled, requestErr) | ||
| } | ||
|
|
||
| func TestCertificateRequesterRequestSuccess(t *testing.T) { | ||
| f := newFixture(0) | ||
| f.requester.Start() | ||
| defer f.tearDown() | ||
|
|
||
| go f.respondRequest(t, 0, nil) | ||
|
|
||
| response, err := f.requester.RequestCertificates(f.ctx) | ||
| assert.NoError(t, err) | ||
| assert.Equal(t, f.interceptedRequestID.Load(), response.GetRequestId()) | ||
| } | ||
|
|
||
| func TestCertificateRequesterResponsesWithUnknownIDAreIgnored(t *testing.T) { | ||
| f := newFixture(100 * time.Millisecond) | ||
| f.requester.Start() | ||
| defer f.tearDown() | ||
|
|
||
| // Request with different request ID should be ignored. | ||
| go f.respondRequest(t, 0, ¢ral.IssueLocalScannerCertsResponse{RequestId: "UNKNOWN"}) | ||
|
|
||
| certs, requestErr := f.requester.RequestCertificates(f.ctx) | ||
| assert.Nil(t, certs) | ||
| assert.Equal(t, context.DeadlineExceeded, requestErr) | ||
| } | ||
|
|
||
| func TestCertificateRequesterRequestConcurrentRequestDoNotInterfere(t *testing.T) { | ||
| testCases := map[string]struct { | ||
| responseDelayFunc func(requestIndex int) (responseDelay time.Duration) | ||
| }{ | ||
| "decreasing response delay": {func(requestIndex int) (responseDelay time.Duration) { | ||
| // responses are responded increasingly faster, so always out of order. | ||
| return time.Duration(numConcurrentRequests-(requestIndex+1)) * 10 * time.Millisecond | ||
| }}, | ||
| "random response delay": {func(requestIndex int) (responseDelay time.Duration) { | ||
| // randomly out of order responses. | ||
| return time.Duration(rand.Intn(100)) * time.Millisecond | ||
| }}, | ||
| } | ||
| for tcName, tc := range testCases { | ||
| t.Run(tcName, func(t *testing.T) { | ||
| f := newFixture(0) | ||
| f.requester.Start() | ||
| defer f.tearDown() | ||
| waitGroup := concurrency.NewWaitGroup(numConcurrentRequests) | ||
|
|
||
| for i := 0; i < numConcurrentRequests; i++ { | ||
| i := i | ||
| responseDelay := tc.responseDelayFunc(i) | ||
| go f.respondRequest(t, responseDelay, nil) | ||
| go func() { | ||
| defer waitGroup.Add(-1) | ||
| _, err := f.requester.RequestCertificates(f.ctx) | ||
| assert.NoError(t, err) | ||
| }() | ||
| } | ||
| ok := concurrency.WaitWithTimeout(&waitGroup, time.Duration(numConcurrentRequests)*testTimeout) | ||
| require.True(t, ok) | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| type certificateRequesterFixture struct { | ||
| sendC chan *central.MsgFromSensor | ||
| receiveC chan *central.IssueLocalScannerCertsResponse | ||
| requester CertificateRequester | ||
| interceptedRequestID *atomic.Value | ||
| ctx context.Context | ||
| cancelCtx context.CancelFunc | ||
| } | ||
|
|
||
| // newFixture creates a new test fixture that uses `timeout` as context timeout if `timeout` is | ||
| // not 0, and `testTimeout` otherwise. | ||
| func newFixture(timeout time.Duration) *certificateRequesterFixture { | ||
| sendC := make(chan *central.MsgFromSensor) | ||
| receiveC := make(chan *central.IssueLocalScannerCertsResponse) | ||
| requester := NewCertificateRequester(sendC, receiveC) | ||
| var interceptedRequestID atomic.Value | ||
| if timeout == 0 { | ||
| timeout = testTimeout | ||
| } | ||
| ctx, cancel := context.WithTimeout(context.Background(), timeout) | ||
| return &certificateRequesterFixture{ | ||
| sendC: sendC, | ||
| receiveC: receiveC, | ||
| requester: requester, | ||
| ctx: ctx, | ||
| cancelCtx: cancel, | ||
| interceptedRequestID: &interceptedRequestID, | ||
| } | ||
| } | ||
|
|
||
| func (f *certificateRequesterFixture) tearDown() { | ||
| f.cancelCtx() | ||
| f.requester.Stop() | ||
| } | ||
|
|
||
| // respondRequest reads a request from `f.sendC` and responds with `responseOverwrite` if not nil, or with | ||
| // a response with the same ID as the request otherwise. If `responseDelay` is greater than 0 then this function | ||
| // waits for that time before sending the response. | ||
| // Before sending the response, it stores in `f.interceptedRequestID` the request ID for the requests read from `f.sendC`. | ||
| func (f *certificateRequesterFixture) respondRequest(t *testing.T, responseDelay time.Duration, responseOverwrite *central.IssueLocalScannerCertsResponse) { | ||
| select { | ||
| case <-f.ctx.Done(): | ||
| case request := <-f.sendC: | ||
| interceptedRequestID := request.GetIssueLocalScannerCertsRequest().GetRequestId() | ||
| assert.NotEmpty(t, interceptedRequestID) | ||
| var response *central.IssueLocalScannerCertsResponse | ||
| if responseOverwrite != nil { | ||
| response = responseOverwrite | ||
| } else { | ||
| response = ¢ral.IssueLocalScannerCertsResponse{RequestId: interceptedRequestID} | ||
| } | ||
| f.interceptedRequestID.Store(response.GetRequestId()) | ||
| if responseDelay > 0 { | ||
| select { | ||
| case <-f.ctx.Done(): | ||
| return | ||
| case <-time.After(responseDelay): | ||
| } | ||
| } | ||
| select { | ||
| case <-f.ctx.Done(): | ||
| case f.receiveC <- response: | ||
| } | ||
| } | ||
| } |
Uh oh!
There was an error while loading. Please reload this page.