From 681334b4f79462320f6a7dadc9ee61853a2fe2b1 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Thu, 20 Jan 2022 18:43:11 +0100 Subject: [PATCH 01/34] First version of RetryableSourceRetriever Type to make requests to a retryable source that has an asynchronous interface, with timeout per request, and configurable backoff --- pkg/retry/retry_source.go | 95 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 pkg/retry/retry_source.go diff --git a/pkg/retry/retry_source.go b/pkg/retry/retry_source.go new file mode 100644 index 0000000000000..cb7324bac788a --- /dev/null +++ b/pkg/retry/retry_source.go @@ -0,0 +1,95 @@ +package retry + +import ( + "context" + "time" + + "github.com/pkg/errors" + "k8s.io/apimachinery/pkg/util/wait" +) + +// RetryableSource is a value that allows asking for a result, and returns the +// corresponding result asynchronously. +// Clients only care about the first value returned in ResultC(). +// AskForResult() can be called several times to retry the result computation, the +// RetryableSource is in charge of handling the cancellation of the computation if needed. +type RetryableSource interface { + AskForResult() + ResultC() chan *Result +} + +// Result wraps a pair (result, err) produced by a source. By convention +// either err or v has the zero value of its type. +type Result struct { + v interface {} + err error +} + +// NewRetryableSourceRetriever create a new NewRetryableSourceRetriever +func NewRetryableSourceRetriever(backoff wait.Backoff, requestTimeout time.Duration) *RetryableSourceRetriever { + return &RetryableSourceRetriever{ + RequestTimeout: requestTimeout, + Backoff: backoff, + } +} + +type RetryableSourceRetriever struct { + // time to consider failed a call to AskForResult() that didn't return a result yet. + RequestTimeout time.Duration + ErrReporter func (err error) + // should be reset between calls to Run. 
+ Backoff wait.Backoff + timeoutC chan struct{} + timeoutTimer *time.Timer +} + +// Run gets the result from the specified source. +// Any timeout in ctx is respected. +func (r *RetryableSourceRetriever) Run(ctx context.Context, source RetryableSource) (interface{}, error) { + r.timeoutC = make(chan struct{}) + + source.AskForResult() + r.setTimeoutTimer(r.RequestTimeout) + defer r.setTimeoutTimer(-1) + for { + select { + case <-ctx.Done(): + return nil, errors.New("request cancelled") + case <-r.timeoutC: + // assume result will never come. + r.handleError(errors.New("timeout"), source) + case result := <- source.ResultC(): + err := result.err + if err != nil { + r.handleError(err, source) + } else { + return result.v, nil + } + } + } +} + +func (r *RetryableSourceRetriever) handleError(err error, source RetryableSource) { + if r.ErrReporter != nil { + r.ErrReporter(err) + } + r.setTimeoutTimer(-1) + time.AfterFunc(r.Backoff.Step(), func() { + source.AskForResult() + r.setTimeoutTimer(r.RequestTimeout) + }) +} + +// use negative timeout to just stop the timer. +func (r *RetryableSourceRetriever) setTimeoutTimer(timeout time.Duration) { + if r.timeoutTimer != nil { + r.timeoutTimer.Stop() + } + if timeout >= 0 { + r.timeoutTimer = time.AfterFunc(timeout, func() { + r.timeoutC <- struct{}{} + }) + } else { + r.timeoutTimer = nil + } +} From b801d57f4f1a55b02a5ad00741eea7d41cb68adb Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Thu, 20 Jan 2022 18:47:18 +0100 Subject: [PATCH 02/34] fix style --- pkg/retry/retry_source.go | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/pkg/retry/retry_source.go b/pkg/retry/retry_source.go index cb7324bac788a..80aea01d1ea3f 100644 --- a/pkg/retry/retry_source.go +++ b/pkg/retry/retry_source.go @@ -21,28 +21,29 @@ type RetryableSource interface { // Result wraps a pair (result, err) produced by a source. By convention // either err or v has the zero value of its type. 
type Result struct { - v interface {} + v interface{} err error } -// NewRetryableSourceRetriever create a new NewRetryableSourceRetriever -func NewRetryableSourceRetriever(backoff wait.Backoff, requestTimeout time.Duration) *RetryableSourceRetriever { - return &RetryableSourceRetriever{ - RequestTimeout: requestTimeout, - Backoff: backoff, - } -} - +// RetryableSourceRetriever be used to retrieve the result in a RetryableSource. type RetryableSourceRetriever struct { // time to consider failed a call to AskForResult() that didn't return a result yet. RequestTimeout time.Duration - ErrReporter func (err error) + ErrReporter func(err error) // should be reset between calls to Run. - Backoff wait.Backoff - timeoutC chan struct{} + Backoff wait.Backoff + timeoutC chan struct{} timeoutTimer *time.Timer } +// NewRetryableSourceRetriever create a new NewRetryableSourceRetriever +func NewRetryableSourceRetriever(backoff wait.Backoff, requestTimeout time.Duration) *RetryableSourceRetriever { + return &RetryableSourceRetriever{ + RequestTimeout: requestTimeout, + Backoff: backoff, + } +} + // Run gets the result from the specified source. // Any timeout in ctx is respected. func (r *RetryableSourceRetriever) Run(ctx context.Context, source RetryableSource) (interface{}, error) { @@ -58,7 +59,7 @@ func (r *RetryableSourceRetriever) Run(ctx context.Context, source RetryableSour case <-r.timeoutC: // assume result will never come. 
r.handleError(errors.New("timeout"), source) - case result := <- source.ResultC(): + case result := <-source.ResultC(): err := result.err if err != nil { r.handleError(err, source) From b8def21079a212c4ba196c8a7d387d679ad3cda5 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Fri, 21 Jan 2022 12:47:30 +0100 Subject: [PATCH 03/34] add error handler and validator --- pkg/retry/retry_source.go | 40 +++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/pkg/retry/retry_source.go b/pkg/retry/retry_source.go index 80aea01d1ea3f..62ac1ead862d6 100644 --- a/pkg/retry/retry_source.go +++ b/pkg/retry/retry_source.go @@ -8,14 +8,16 @@ import ( "k8s.io/apimachinery/pkg/util/wait" ) -// RetryableSource is a value that allows asking for a result, and returns the -// corresponding result asynchronously. -// Clients only care about the first value returned in ResultC(). -// AskForResult() can be called several times to retry the result computation, the +// RetryableSource is a proxy with an object that is able to compute a result, but +// that might forget our request, or return an error result, and that returns the +// result asynchronously. +// AskForResult() can be called to request a result, that should be make available in the +// returned channel. Each time AskForResult() is called the previously returned channel is abandoned. +// Retry() can be called several times to retry the result computation, the // RetryableSource is in charge of handling the cancellation of the computation if needed. type RetryableSource interface { - AskForResult() - ResultC() chan *Result + AskForResult() chan *Result + Retry() } // Result wraps a pair (result, err) produced by a source. By convention @@ -29,7 +31,11 @@ type Result struct { type RetryableSourceRetriever struct { // time to consider failed a call to AskForResult() that didn't return a result yet. 
RequestTimeout time.Duration - ErrReporter func(err error) + // optionally specify a function to invoke on each error. waitDuration is the time until + // the next retry. + OnError func(err error, timeToNextRetry time.Duration) + // optionally specify a validation function for each result. + ValidateResult func(interface{}) bool // should be reset between calls to Run. Backoff wait.Backoff timeoutC chan struct{} @@ -49,7 +55,7 @@ func NewRetryableSourceRetriever(backoff wait.Backoff, requestTimeout time.Durat func (r *RetryableSourceRetriever) Run(ctx context.Context, source RetryableSource) (interface{}, error) { r.timeoutC = make(chan struct{}) - source.AskForResult() + resultC := source.AskForResult() r.setTimeoutTimer(r.RequestTimeout) defer r.setTimeoutTimer(-1) for { @@ -59,24 +65,30 @@ func (r *RetryableSourceRetriever) Run(ctx context.Context, source RetryableSour case <-r.timeoutC: // assume result will never come. r.handleError(errors.New("timeout"), source) - case result := <-source.ResultC(): + case result := <-resultC: err := result.err if err != nil { r.handleError(err, source) } else { - return result.v, nil + if r.ValidateResult != nil && !r.ValidateResult(result.v) { + err := errors.Errorf("validation failed for value %v", result.v) + r.handleError(err, source) + } else { + return result.v, nil + } } } } } func (r *RetryableSourceRetriever) handleError(err error, source RetryableSource) { - if r.ErrReporter != nil { - r.ErrReporter(err) + waitDuration := r.Backoff.Step() + if r.OnError != nil { + r.OnError(err, waitDuration) } r.setTimeoutTimer(-1) - time.AfterFunc(r.Backoff.Step(), func() { - source.AskForResult() + time.AfterFunc(waitDuration, func() { + source.Retry() r.setTimeoutTimer(r.RequestTimeout) }) } From b35ff30c8aeb7f41ac4b7598809c8b945d1dbce6 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Fri, 21 Jan 2022 12:50:54 +0100 Subject: [PATCH 04/34] Pass context on AskForResult --- pkg/retry/retry_source.go | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/retry/retry_source.go b/pkg/retry/retry_source.go index 62ac1ead862d6..f85440eeaa158 100644 --- a/pkg/retry/retry_source.go +++ b/pkg/retry/retry_source.go @@ -16,7 +16,7 @@ import ( // Retry() can be called several times to retry the result computation, the // RetryableSource is in charge of handling the cancellation of the computation if needed. type RetryableSource interface { - AskForResult() chan *Result + AskForResult(ctx context.Context) chan *Result Retry() } @@ -55,7 +55,7 @@ func NewRetryableSourceRetriever(backoff wait.Backoff, requestTimeout time.Durat func (r *RetryableSourceRetriever) Run(ctx context.Context, source RetryableSource) (interface{}, error) { r.timeoutC = make(chan struct{}) - resultC := source.AskForResult() + resultC := source.AskForResult(ctx) r.setTimeoutTimer(r.RequestTimeout) defer r.setTimeoutTimer(-1) for { From de7d6f2117b878603d8aeea3f6859dbbe3358737 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Tue, 18 Jan 2022 18:57:57 +0100 Subject: [PATCH 05/34] Draft CertManager --- .../kubernetes/certificates/cert_manager.go | 303 ++++++++++++++++++ .../certificates/cert_manager_test.go | 125 ++++++++ 2 files changed, 428 insertions(+) create mode 100644 sensor/kubernetes/certificates/cert_manager.go create mode 100644 sensor/kubernetes/certificates/cert_manager_test.go diff --git a/sensor/kubernetes/certificates/cert_manager.go b/sensor/kubernetes/certificates/cert_manager.go new file mode 100644 index 0000000000000..d37dc0c400c03 --- /dev/null +++ b/sensor/kubernetes/certificates/cert_manager.go @@ -0,0 +1,303 @@ +package certificates + +import ( + "context" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/pkg/errors" + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/concurrency" + "github.com/stackrox/rox/pkg/logging" + v1 "k8s.io/api/core/v1" + k8sErrors "k8s.io/apimachinery/pkg/api/errors" + metav1 
"k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/wait" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/client-go/util/retry" +) + +const ( + // FIXME adjust + internalChannelBuffSize = 50 + defaultCentralRequestTimeout = time.Minute +) + +var ( + log = logging.LoggerForModule() + // FIXME adjust + k8sAPIBackoff = retry.DefaultBackoff + _ SecretsExpirationStrategy = (*secretsExpirationStrategyImpl)(nil) + + _ CertManager = (*certManagerImpl)(nil) +) + +type SecretsExpirationStrategy interface { + GetSecretsDuration(secrets map[storage.ServiceType]*v1.Secret) time.Duration +} + +// CertManager is in charge of storing and refreshing service TLS certificates in a set of k8s secrets. +type CertManager interface { + Start(ctx context.Context) error + Stop() + // HandleIssueCertificatesResponse handles a certificate response from central. + // - Precondition: if issueError is nil then certificates is not nil. + // - Implementations should handle a nil receiver like an unknown request ID. + HandleIssueCertificatesResponse(requestID string, issueError error, certificates *storage.TypedServiceCertificateSet) error +} + +type CertIssuanceFunc func(CertManager) (requestID string, err error) +type certManagerImpl struct { + // should be kept constant. + secretNames map[storage.ServiceType]string + secretsClient corev1.SecretInterface + issueCerts CertIssuanceFunc + stopC concurrency.ErrorSignal + centralRequestTimeout time.Duration + centralBackoffProto wait.Backoff + secretExpiration SecretsExpirationStrategy + // set at Start(). + ctx context.Context + // handled by loop goroutine. 
+ dispatchC chan interface{} + requestStatus *requestStatus + refreshTimer *time.Timer + certIssueRequestTimeoutTimer *time.Timer +} + +type requestStatus struct { + requestID string + backoff wait.Backoff +} + +func NewCertManager(secretsClient corev1.SecretInterface, secretNames map[storage.ServiceType]string, + centralBackoff wait.Backoff, issueCerts CertIssuanceFunc) CertManager { + return newCertManager(secretsClient, secretNames, centralBackoff, issueCerts) +} + +func newCertManager(secretsClient corev1.SecretInterface, secretNames map[storage.ServiceType]string, + centralBackoff wait.Backoff, issueCerts CertIssuanceFunc) *certManagerImpl { + return &certManagerImpl{ + secretNames: secretNames, + secretsClient: secretsClient, + issueCerts: issueCerts, + stopC: concurrency.NewErrorSignal(), + centralRequestTimeout: defaultCentralRequestTimeout, + centralBackoffProto: centralBackoff, + secretExpiration: &secretsExpirationStrategyImpl{}, + dispatchC: make(chan interface{}, internalChannelBuffSize), + requestStatus: &requestStatus{}, + } +} + +func (c *certManagerImpl) Start(ctx context.Context) error { + c.ctx = ctx + secrets, err := c.fetchSecrets() + if err != nil { + return errors.Wrapf(err, "fetching secrets %v", c.secretNames) + } + // this refreshes immediately if certificates are already expired. 
+ c.scheduleIssueCertificatesRefresh(c.secretExpiration.GetSecretsDuration(secrets)) + + go c.loop() + + return nil +} + +func (c *certManagerImpl) Stop() { + c.stopC.Signal() +} + +func (c *certManagerImpl) issueCertificates() (requestID string, err error){ + return c.issueCerts(c) +} + +func (c *certManagerImpl) loop() { + // FIXME: protect private methods and fields + for { + select { + case msg := <- c.dispatchC: + switch m := msg.(type) { + case requestCertificates: + c.requestCertificates() + case handleIssueCertificatesResponse: + c.doHandleIssueCertificatesResponse(m.requestID, m.issueError, m.certificates) + case issueCertificatesTimeout: + c.issueCertificatesTimeout(m.requestID) + default: + log.Errorf("received unknown message %v, message will be ignored", msg) + } + + case <-c.stopC.Done(): + c.doStop() + return + } + } +} + +type handleIssueCertificatesResponse struct { + requestID string + issueError error + certificates *storage.TypedServiceCertificateSet +} + +type requestCertificates struct {} + +type issueCertificatesTimeout struct { + requestID string +} + +func (c *certManagerImpl) setRefreshTimer(timer *time.Timer){ + if c.refreshTimer != nil { + c.refreshTimer.Stop() + } + c.refreshTimer = timer +} + +func (c *certManagerImpl) setCertIssueRequestTimeoutTimer(timer *time.Timer){ + if c.certIssueRequestTimeoutTimer != nil { + c.certIssueRequestTimeoutTimer.Stop() + } + c.certIssueRequestTimeoutTimer = timer +} + +// set request id, and reset timers and retry backoff. 
+func (c *certManagerImpl) setRequestId(requestID string) { + c.requestStatus.requestID = requestID + c.requestStatus.backoff = c.centralBackoffProto + c.setRefreshTimer(nil) + c.setCertIssueRequestTimeoutTimer(nil) +} + + +func (c *certManagerImpl) HandleIssueCertificatesResponse(requestID string, issueError error, certificates *storage.TypedServiceCertificateSet) error { + if c == nil { + return errors.Errorf("unknown request ID %s, potentially due to request timeout", requestID) + } + c.dispatchC <- handleIssueCertificatesResponse{requestID: requestID, issueError: issueError, certificates: certificates} + return nil +} + +// should only be called from the loop goroutine. +func (c *certManagerImpl) requestCertificates() { + if requestID, err := c.issueCertificates(); err != nil { + // client side error + log.Errorf("client error sending request to issue certificates for secrets %v: %s", + c.secretNames, err) + c.scheduleRetryIssueCertificatesRefresh() + } else { + c.setRequestId(requestID) + c.setCertIssueRequestTimeoutTimer(time.AfterFunc(c.centralRequestTimeout, func() { + c.dispatchC <- issueCertificatesTimeout{requestID: requestID} + })) + } +} + +// should only be called from the loop goroutine. +func (c *certManagerImpl) doHandleIssueCertificatesResponse(requestID string, issueError error, certificates *storage.TypedServiceCertificateSet) { + if requestID != c.requestStatus.requestID { + // silently ignore responses sent to the wrong CertManager. + log.Debugf("ignoring issue certificate response from unknown request id %q", requestID) + return + } + c.setRequestId("") + + if issueError != nil { + // server side error. 
+ log.Errorf("server side error issuing certificates for secrets %v: %s", c.secretNames, issueError) + c.scheduleRetryIssueCertificatesRefresh() + return + } + + nextTimeToRefresh, refreshErr := c.refreshSecrets(certificates) + if refreshErr != nil { + log.Errorf("failure to store the new certificates in the secrets %v: %s", c.secretNames, refreshErr) + c.scheduleRetryIssueCertificatesRefresh() + return + } + + log.Infof("successfully refreshed credential in secrets %v", c.secretNames) + c.scheduleIssueCertificatesRefresh(nextTimeToRefresh) +} + +// should only be called from the loop goroutine. +func (c *certManagerImpl) issueCertificatesTimeout(requestID string) { + if requestID != c.requestStatus.requestID { + // this is a timeout for a request we don't care about anymore. + log.Debugf("ignoring timeout on issue certificate request from unknown request id %q", requestID) + return + } + log.Errorf("timeout waiting for certificates for secrets %v on request with id %q after waiting for %s", + c.secretNames, requestID, c.centralRequestTimeout) + // ignore eventual responses for this request. + c.setRequestId("") + c.scheduleRetryIssueCertificatesRefresh() +} + +// should only be called from the loop goroutine. 
+func (c *certManagerImpl) doStop() { + c.setRequestId("") + log.Info("CertManager stopped.") +} + +func (c *certManagerImpl) scheduleRetryIssueCertificatesRefresh() { + c.scheduleIssueCertificatesRefresh(c.requestStatus.backoff.Step()) +} + +func (c *certManagerImpl) scheduleIssueCertificatesRefresh(timeToRefresh time.Duration) { + log.Infof("certificates for secrets %v scheduled to be refreshed in %s", c.secretNames, timeToRefresh) + c.setRefreshTimer(time.AfterFunc(timeToRefresh, func() { + c.dispatchC <- requestCertificates{} + })) +} + +func (c *certManagerImpl) fetchSecrets() (map[storage.ServiceType]*v1.Secret, error) { + secretsMap := make(map[storage.ServiceType]*v1.Secret, len(c.secretNames)) + var fetchErr error + for serviceType, secretName := range c.secretNames { + var ( + secret *v1.Secret + err error + ) + retryErr := retry.OnError(k8sAPIBackoff, + func(err error) bool { + return !k8sErrors.IsNotFound(err) + }, + func() error { + secret, err = c.secretsClient.Get(c.ctx, secretName, metav1.GetOptions{}) + return err + }, + ) + if retryErr != nil{ + fetchErr = multierror.Append(fetchErr, errors.Wrapf(retryErr,"for secret %s", secretName)) + } else { + secretsMap[serviceType] = secret + } + } + + if fetchErr != nil { + return nil, fetchErr + } + return secretsMap, nil +} + +// Performs retries for reads and writes with the k8s API. +// On success, it returns the duration after which the secrets should be refreshed. 
+func (c *certManagerImpl) refreshSecrets(certificates *storage.TypedServiceCertificateSet) (time.Duration, error) { + secrets, err := c.fetchSecrets() + if err != nil { + // FIXME wrap + return 0, err + } + // TODO update secrets ROX-8969 + + return c.secretExpiration.GetSecretsDuration(secrets), nil +} + +type secretsExpirationStrategyImpl struct {} + +func (s *secretsExpirationStrategyImpl) GetSecretsDuration(secrets map[storage.ServiceType]*v1.Secret) time.Duration { + // TODO ROX-8969 + return 5 * time.Second +} + diff --git a/sensor/kubernetes/certificates/cert_manager_test.go b/sensor/kubernetes/certificates/cert_manager_test.go new file mode 100644 index 0000000000000..5fd8780b03786 --- /dev/null +++ b/sensor/kubernetes/certificates/cert_manager_test.go @@ -0,0 +1,125 @@ +package certificates + +import ( + "context" + "testing" + "time" + + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/concurrency" + "github.com/stretchr/testify/suite" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/kubernetes/fake" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" +) + +const ( + namespace = "namespace" + requestID = "requestID" +) +var ( + centralBackoff = wait.Backoff{ + Steps: 3, + Duration: 10 * time.Millisecond, + Factor: 10.0, + Jitter: 0.1, + Cap: 2 * time.Second, + } +) + +func TestHandler(t *testing.T) { + suite.Run(t, new(certManagerSuite)) +} + +type certManagerSuite struct { + suite.Suite +} + +func fakeClientSet(secretNames ...string) *fake.Clientset{ + secrets := make([]runtime.Object, len(secretNames)) + for i, secretName := range secretNames { + secrets[i] = &v1.Secret{ObjectMeta: metav1.ObjectMeta{Name: secretName, Namespace: namespace}} + } + return fake.NewSimpleClientset(secrets...) 
+} + +func fakeSecretsClient(secretNames ...string) corev1.SecretInterface { + return fakeClientSet(secretNames...).CoreV1().Secrets(namespace) +} + +type fixedSecretsExpirationStrategy struct { + durations []time.Duration + invocations int + signal concurrency.ErrorSignal +} + +func newFixedSecretsExpirationStrategy(durations ...time.Duration) *fixedSecretsExpirationStrategy{ + return &fixedSecretsExpirationStrategy{ + durations: durations, + signal: concurrency.NewErrorSignal(), + } +} + +// signals .signal when the last timeout is reached +func (s *fixedSecretsExpirationStrategy) GetSecretsDuration(secrets map[storage.ServiceType]*v1.Secret) (duration time.Duration) { + s.invocations ++ + if len(s.durations) <= 1{ + s.signal.Signal() + return s.durations[0] + } + + duration, s.durations = s.durations[0], s.durations[1:] + return duration +} + + +func (s *certManagerSuite) TestSuccessfulRefresh() { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + + secretName := "foo" + secretNames := map[storage.ServiceType]string{ + storage.ServiceType_SCANNER_DB_SERVICE: secretName, + } + secretsClient := fakeSecretsClient(secretName) + certManager := newCertManager(secretsClient, secretNames, centralBackoff, + func(manager CertManager) (requestID string, err error) { + // FIXME nil certs + manager.HandleIssueCertificatesResponse(requestID, nil, nil) + return requestID, nil + }) + certManager.centralRequestTimeout = 2 * time.Second + expirationStrategy := newFixedSecretsExpirationStrategy(0, 2 * time.Second) + certManager.secretExpiration = expirationStrategy + + certManager.Start(ctx) + defer certManager.Stop() + + // FIXME: idea, add error handler fields to certManagerImpl, that processes errors in loop, and + // make the processing functions return err. 
For prod it just logs; for test it keeps a slice of + // errors or something we can inspect + + waitErr, ok := expirationStrategy.signal.WaitUntil(ctx) + s.Require().True(ok) + s.NoError(waitErr) + + // TODO assert certManager not stopped + + s.Empty(certManager.requestStatus.requestID) +} + + +/* +TODO failures: + +- success +- server failure +- client failure +- timeout + +in all check retries as expected +*/ \ No newline at end of file From a8019a0a2d61ff4aa5ca5732b21b2aca75d1e7c9 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Wed, 19 Jan 2022 11:59:03 +0100 Subject: [PATCH 06/34] Remove references to central, and checkstyle --- .../kubernetes/certificates/cert_manager.go | 102 +++++++++--------- .../certificates/cert_manager_test.go | 37 ++++--- 2 files changed, 69 insertions(+), 70 deletions(-) diff --git a/sensor/kubernetes/certificates/cert_manager.go b/sensor/kubernetes/certificates/cert_manager.go index d37dc0c400c03..9df42f718d115 100644 --- a/sensor/kubernetes/certificates/cert_manager.go +++ b/sensor/kubernetes/certificates/cert_manager.go @@ -19,20 +19,20 @@ import ( const ( // FIXME adjust - internalChannelBuffSize = 50 - defaultCentralRequestTimeout = time.Minute + internalChannelBuffSize = 50 + defaultCertRequestTimeout = time.Minute ) var ( log = logging.LoggerForModule() // FIXME adjust - k8sAPIBackoff = retry.DefaultBackoff - _ SecretsExpirationStrategy = (*secretsExpirationStrategyImpl)(nil) + k8sAPIBackoff = retry.DefaultBackoff + _ secretsExpirationStrategy = (*secretsExpirationStrategyImpl)(nil) _ CertManager = (*certManagerImpl)(nil) ) -type SecretsExpirationStrategy interface { +type secretsExpirationStrategy interface { GetSecretsDuration(secrets map[storage.ServiceType]*v1.Secret) time.Duration } @@ -40,53 +40,55 @@ type SecretsExpirationStrategy interface { type CertManager interface { Start(ctx context.Context) error Stop() - // HandleIssueCertificatesResponse handles a certificate response from central. 
+ // HandleIssueCertificatesResponse handles a certificate issue response. // - Precondition: if issueError is nil then certificates is not nil. // - Implementations should handle a nil receiver like an unknown request ID. HandleIssueCertificatesResponse(requestID string, issueError error, certificates *storage.TypedServiceCertificateSet) error } +// CertIssuanceFunc can be used to request a certificate. type CertIssuanceFunc func(CertManager) (requestID string, err error) type certManagerImpl struct { // should be kept constant. - secretNames map[storage.ServiceType]string - secretsClient corev1.SecretInterface - issueCerts CertIssuanceFunc - stopC concurrency.ErrorSignal - centralRequestTimeout time.Duration - centralBackoffProto wait.Backoff - secretExpiration SecretsExpirationStrategy + secretNames map[storage.ServiceType]string + secretsClient corev1.SecretInterface + issueCerts CertIssuanceFunc + stopC concurrency.ErrorSignal + certRequestTimeout time.Duration + certRequestBackoffProto wait.Backoff + secretExpiration secretsExpirationStrategy // set at Start(). - ctx context.Context + ctx context.Context // handled by loop goroutine. - dispatchC chan interface{} - requestStatus *requestStatus - refreshTimer *time.Timer + dispatchC chan interface{} + requestStatus *requestStatus + refreshTimer *time.Timer certIssueRequestTimeoutTimer *time.Timer } type requestStatus struct { requestID string - backoff wait.Backoff + backoff wait.Backoff } +// NewCertManager creates a new CertManager. 
func NewCertManager(secretsClient corev1.SecretInterface, secretNames map[storage.ServiceType]string, - centralBackoff wait.Backoff, issueCerts CertIssuanceFunc) CertManager { - return newCertManager(secretsClient, secretNames, centralBackoff, issueCerts) + certRequestBackoff wait.Backoff, issueCerts CertIssuanceFunc) CertManager { + return newCertManager(secretsClient, secretNames, certRequestBackoff, issueCerts) } func newCertManager(secretsClient corev1.SecretInterface, secretNames map[storage.ServiceType]string, - centralBackoff wait.Backoff, issueCerts CertIssuanceFunc) *certManagerImpl { + certRequestBackoff wait.Backoff, issueCerts CertIssuanceFunc) *certManagerImpl { return &certManagerImpl{ - secretNames: secretNames, - secretsClient: secretsClient, - issueCerts: issueCerts, - stopC: concurrency.NewErrorSignal(), - centralRequestTimeout: defaultCentralRequestTimeout, - centralBackoffProto: centralBackoff, - secretExpiration: &secretsExpirationStrategyImpl{}, - dispatchC: make(chan interface{}, internalChannelBuffSize), - requestStatus: &requestStatus{}, + secretNames: secretNames, + secretsClient: secretsClient, + issueCerts: issueCerts, + stopC: concurrency.NewErrorSignal(), + certRequestTimeout: defaultCertRequestTimeout, + certRequestBackoffProto: certRequestBackoff, + secretExpiration: &secretsExpirationStrategyImpl{}, + dispatchC: make(chan interface{}, internalChannelBuffSize), + requestStatus: &requestStatus{}, } } @@ -108,7 +110,7 @@ func (c *certManagerImpl) Stop() { c.stopC.Signal() } -func (c *certManagerImpl) issueCertificates() (requestID string, err error){ +func (c *certManagerImpl) issueCertificates() (requestID string, err error) { return c.issueCerts(c) } @@ -116,7 +118,7 @@ func (c *certManagerImpl) loop() { // FIXME: protect private methods and fields for { select { - case msg := <- c.dispatchC: + case msg := <-c.dispatchC: switch m := msg.(type) { case requestCertificates: c.requestCertificates() @@ -136,25 +138,25 @@ func (c 
*certManagerImpl) loop() { } type handleIssueCertificatesResponse struct { - requestID string - issueError error + requestID string + issueError error certificates *storage.TypedServiceCertificateSet } -type requestCertificates struct {} +type requestCertificates struct{} type issueCertificatesTimeout struct { requestID string } -func (c *certManagerImpl) setRefreshTimer(timer *time.Timer){ +func (c *certManagerImpl) setRefreshTimer(timer *time.Timer) { if c.refreshTimer != nil { c.refreshTimer.Stop() } c.refreshTimer = timer } -func (c *certManagerImpl) setCertIssueRequestTimeoutTimer(timer *time.Timer){ +func (c *certManagerImpl) setCertIssueRequestTimeoutTimer(timer *time.Timer) { if c.certIssueRequestTimeoutTimer != nil { c.certIssueRequestTimeoutTimer.Stop() } @@ -162,14 +164,13 @@ func (c *certManagerImpl) setCertIssueRequestTimeoutTimer(timer *time.Timer){ } // set request id, and reset timers and retry backoff. -func (c *certManagerImpl) setRequestId(requestID string) { +func (c *certManagerImpl) setRequestID(requestID string) { c.requestStatus.requestID = requestID - c.requestStatus.backoff = c.centralBackoffProto + c.requestStatus.backoff = c.certRequestBackoffProto c.setRefreshTimer(nil) c.setCertIssueRequestTimeoutTimer(nil) } - func (c *certManagerImpl) HandleIssueCertificatesResponse(requestID string, issueError error, certificates *storage.TypedServiceCertificateSet) error { if c == nil { return errors.Errorf("unknown request ID %s, potentially due to request timeout", requestID) @@ -186,8 +187,8 @@ func (c *certManagerImpl) requestCertificates() { c.secretNames, err) c.scheduleRetryIssueCertificatesRefresh() } else { - c.setRequestId(requestID) - c.setCertIssueRequestTimeoutTimer(time.AfterFunc(c.centralRequestTimeout, func() { + c.setRequestID(requestID) + c.setCertIssueRequestTimeoutTimer(time.AfterFunc(c.certRequestTimeout, func() { c.dispatchC <- issueCertificatesTimeout{requestID: requestID} })) } @@ -200,7 +201,7 @@ func (c *certManagerImpl) 
doHandleIssueCertificatesResponse(requestID string, is log.Debugf("ignoring issue certificate response from unknown request id %q", requestID) return } - c.setRequestId("") + c.setRequestID("") if issueError != nil { // server side error. @@ -228,15 +229,15 @@ func (c *certManagerImpl) issueCertificatesTimeout(requestID string) { return } log.Errorf("timeout waiting for certificates for secrets %v on request with id %q after waiting for %s", - c.secretNames, requestID, c.centralRequestTimeout) + c.secretNames, requestID, c.certRequestTimeout) // ignore eventual responses for this request. - c.setRequestId("") + c.setRequestID("") c.scheduleRetryIssueCertificatesRefresh() } // should only be called from the loop goroutine. func (c *certManagerImpl) doStop() { - c.setRequestId("") + c.setRequestID("") log.Info("CertManager stopped.") } @@ -257,7 +258,7 @@ func (c *certManagerImpl) fetchSecrets() (map[storage.ServiceType]*v1.Secret, er for serviceType, secretName := range c.secretNames { var ( secret *v1.Secret - err error + err error ) retryErr := retry.OnError(k8sAPIBackoff, func(err error) bool { @@ -268,8 +269,8 @@ func (c *certManagerImpl) fetchSecrets() (map[storage.ServiceType]*v1.Secret, er return err }, ) - if retryErr != nil{ - fetchErr = multierror.Append(fetchErr, errors.Wrapf(retryErr,"for secret %s", secretName)) + if retryErr != nil { + fetchErr = multierror.Append(fetchErr, errors.Wrapf(retryErr, "for secret %s", secretName)) } else { secretsMap[serviceType] = secret } @@ -294,10 +295,9 @@ func (c *certManagerImpl) refreshSecrets(certificates *storage.TypedServiceCerti return c.secretExpiration.GetSecretsDuration(secrets), nil } -type secretsExpirationStrategyImpl struct {} +type secretsExpirationStrategyImpl struct{} func (s *secretsExpirationStrategyImpl) GetSecretsDuration(secrets map[storage.ServiceType]*v1.Secret) time.Duration { // TODO ROX-8969 return 5 * time.Second } - diff --git a/sensor/kubernetes/certificates/cert_manager_test.go 
b/sensor/kubernetes/certificates/cert_manager_test.go index 5fd8780b03786..d5fa32543bfaf 100644 --- a/sensor/kubernetes/certificates/cert_manager_test.go +++ b/sensor/kubernetes/certificates/cert_manager_test.go @@ -20,13 +20,14 @@ const ( namespace = "namespace" requestID = "requestID" ) + var ( - centralBackoff = wait.Backoff{ + requestBackoff = wait.Backoff{ Steps: 3, Duration: 10 * time.Millisecond, Factor: 10.0, Jitter: 0.1, - Cap: 2 * time.Second, + Cap: 2 * time.Second, } ) @@ -38,7 +39,7 @@ type certManagerSuite struct { suite.Suite } -func fakeClientSet(secretNames ...string) *fake.Clientset{ +func fakeClientSet(secretNames ...string) *fake.Clientset { secrets := make([]runtime.Object, len(secretNames)) for i, secretName := range secretNames { secrets[i] = &v1.Secret{ObjectMeta: metav1.ObjectMeta{Name: secretName, Namespace: namespace}} @@ -51,22 +52,22 @@ func fakeSecretsClient(secretNames ...string) corev1.SecretInterface { } type fixedSecretsExpirationStrategy struct { - durations []time.Duration + durations []time.Duration invocations int - signal concurrency.ErrorSignal + signal concurrency.ErrorSignal } -func newFixedSecretsExpirationStrategy(durations ...time.Duration) *fixedSecretsExpirationStrategy{ +func newFixedSecretsExpirationStrategy(durations ...time.Duration) *fixedSecretsExpirationStrategy { return &fixedSecretsExpirationStrategy{ durations: durations, - signal: concurrency.NewErrorSignal(), + signal: concurrency.NewErrorSignal(), } } // signals .signal when the last timeout is reached func (s *fixedSecretsExpirationStrategy) GetSecretsDuration(secrets map[storage.ServiceType]*v1.Secret) (duration time.Duration) { - s.invocations ++ - if len(s.durations) <= 1{ + s.invocations++ + if len(s.durations) <= 1 { s.signal.Signal() return s.durations[0] } @@ -75,7 +76,6 @@ func (s *fixedSecretsExpirationStrategy) GetSecretsDuration(secrets map[storage. 
return duration } - func (s *certManagerSuite) TestSuccessfulRefresh() { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 1*time.Second) @@ -86,18 +86,18 @@ func (s *certManagerSuite) TestSuccessfulRefresh() { storage.ServiceType_SCANNER_DB_SERVICE: secretName, } secretsClient := fakeSecretsClient(secretName) - certManager := newCertManager(secretsClient, secretNames, centralBackoff, + certManager := newCertManager(secretsClient, secretNames, requestBackoff, func(manager CertManager) (requestID string, err error) { // FIXME nil certs - manager.HandleIssueCertificatesResponse(requestID, nil, nil) + s.Require().NoError(manager.HandleIssueCertificatesResponse(requestID, nil, nil)) return requestID, nil - }) - certManager.centralRequestTimeout = 2 * time.Second - expirationStrategy := newFixedSecretsExpirationStrategy(0, 2 * time.Second) + }) + defer certManager.Stop() + certManager.certRequestTimeout = 2 * time.Second + expirationStrategy := newFixedSecretsExpirationStrategy(0, 2*time.Second) certManager.secretExpiration = expirationStrategy - certManager.Start(ctx) - defer certManager.Stop() + s.Require().NoError(certManager.Start(ctx)) // FIXME: idea, add error handler fields to certManagerImpl, that processes errors in loop, and // make the processing functions return err. 
For prod it just logs; for test it keeps a slice of @@ -112,7 +112,6 @@ func (s *certManagerSuite) TestSuccessfulRefresh() { s.Empty(certManager.requestStatus.requestID) } - /* TODO failures: @@ -122,4 +121,4 @@ TODO failures: - timeout in all check retries as expected -*/ \ No newline at end of file +*/ From 38e1959a19aa18cecabdeaba08a446c7d994cff2 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Wed, 19 Jan 2022 12:20:12 +0100 Subject: [PATCH 07/34] Only reset backoff on successful refresh --- sensor/kubernetes/certificates/cert_manager.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/sensor/kubernetes/certificates/cert_manager.go b/sensor/kubernetes/certificates/cert_manager.go index 9df42f718d115..22dbdd15212c7 100644 --- a/sensor/kubernetes/certificates/cert_manager.go +++ b/sensor/kubernetes/certificates/cert_manager.go @@ -163,14 +163,18 @@ func (c *certManagerImpl) setCertIssueRequestTimeoutTimer(timer *time.Timer) { c.certIssueRequestTimeoutTimer = timer } -// set request id, and reset timers and retry backoff. +// set request id, and stops timers. func (c *certManagerImpl) setRequestID(requestID string) { c.requestStatus.requestID = requestID - c.requestStatus.backoff = c.certRequestBackoffProto c.setRefreshTimer(nil) c.setCertIssueRequestTimeoutTimer(nil) } +// reset retry backoff. +func (c *certManagerImpl) resetBackoff() { + c.requestStatus.backoff = c.certRequestBackoffProto +} + func (c *certManagerImpl) HandleIssueCertificatesResponse(requestID string, issueError error, certificates *storage.TypedServiceCertificateSet) error { if c == nil { return errors.Errorf("unknown request ID %s, potentially due to request timeout", requestID) @@ -201,7 +205,6 @@ func (c *certManagerImpl) doHandleIssueCertificatesResponse(requestID string, is log.Debugf("ignoring issue certificate response from unknown request id %q", requestID) return } - c.setRequestID("") if issueError != nil { // server side error. 
@@ -218,6 +221,7 @@ func (c *certManagerImpl) doHandleIssueCertificatesResponse(requestID string, is } log.Infof("successfully refreshed credential in secrets %v", c.secretNames) + c.resetBackoff() c.scheduleIssueCertificatesRefresh(nextTimeToRefresh) } @@ -230,8 +234,6 @@ func (c *certManagerImpl) issueCertificatesTimeout(requestID string) { } log.Errorf("timeout waiting for certificates for secrets %v on request with id %q after waiting for %s", c.secretNames, requestID, c.certRequestTimeout) - // ignore eventual responses for this request. - c.setRequestID("") c.scheduleRetryIssueCertificatesRefresh() } @@ -247,6 +249,8 @@ func (c *certManagerImpl) scheduleRetryIssueCertificatesRefresh() { func (c *certManagerImpl) scheduleIssueCertificatesRefresh(timeToRefresh time.Duration) { log.Infof("certificates for secrets %v scheduled to be refreshed in %s", c.secretNames, timeToRefresh) + // ignore eventual responses for this request. + c.setRequestID("") c.setRefreshTimer(time.AfterFunc(timeToRefresh, func() { c.dispatchC <- requestCertificates{} })) From d8f4a1f826c02a0104deea0f46af5edb5439ee7c Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Wed, 19 Jan 2022 19:43:58 +0100 Subject: [PATCH 08/34] outline of retries tests --- .../kubernetes/certificates/cert_manager.go | 99 +++++++----- .../certificates/cert_manager_test.go | 149 ++++++++++++++---- 2 files changed, 177 insertions(+), 71 deletions(-) diff --git a/sensor/kubernetes/certificates/cert_manager.go b/sensor/kubernetes/certificates/cert_manager.go index 22dbdd15212c7..359ff3796fdbf 100644 --- a/sensor/kubernetes/certificates/cert_manager.go +++ b/sensor/kubernetes/certificates/cert_manager.go @@ -36,6 +36,14 @@ type secretsExpirationStrategy interface { GetSecretsDuration(secrets map[storage.ServiceType]*v1.Secret) time.Duration } +type errorReporter interface { + Report(err error) +} + +type jobScheduler interface { + AfterFunc(d time.Duration, f func()) *time.Timer +} + // CertManager is in charge 
of storing and refreshing service TLS certificates in a set of k8s secrets. type CertManager interface { Start(ctx context.Context) error @@ -56,7 +64,9 @@ type certManagerImpl struct { stopC concurrency.ErrorSignal certRequestTimeout time.Duration certRequestBackoffProto wait.Backoff - secretExpiration secretsExpirationStrategy + expirationStrategy secretsExpirationStrategy + errorReporter errorReporter + jobScheduler jobScheduler // set at Start(). ctx context.Context // handled by loop goroutine. @@ -86,7 +96,9 @@ func newCertManager(secretsClient corev1.SecretInterface, secretNames map[storag stopC: concurrency.NewErrorSignal(), certRequestTimeout: defaultCertRequestTimeout, certRequestBackoffProto: certRequestBackoff, - secretExpiration: &secretsExpirationStrategyImpl{}, + expirationStrategy: &secretsExpirationStrategyImpl{}, + errorReporter: &errorReporterImpl{}, + jobScheduler: &jobSchedulerImpl{}, dispatchC: make(chan interface{}, internalChannelBuffSize), requestStatus: &requestStatus{}, } @@ -99,7 +111,7 @@ func (c *certManagerImpl) Start(ctx context.Context) error { return errors.Wrapf(err, "fetching secrets %v", c.secretNames) } // this refreshes immediately if certificates are already expired. 
- c.scheduleIssueCertificatesRefresh(c.secretExpiration.GetSecretsDuration(secrets)) + c.scheduleIssueCertificatesRefresh(c.expirationStrategy.GetSecretsDuration(secrets)) go c.loop() @@ -110,10 +122,6 @@ func (c *certManagerImpl) Stop() { c.stopC.Signal() } -func (c *certManagerImpl) issueCertificates() (requestID string, err error) { - return c.issueCerts(c) -} - func (c *certManagerImpl) loop() { // FIXME: protect private methods and fields for { @@ -121,17 +129,17 @@ func (c *certManagerImpl) loop() { case msg := <-c.dispatchC: switch m := msg.(type) { case requestCertificates: - c.requestCertificates() + c.errorReporter.Report(c.requestCertificates()) case handleIssueCertificatesResponse: - c.doHandleIssueCertificatesResponse(m.requestID, m.issueError, m.certificates) + c.errorReporter.Report(c.handleIssueCertificatesResponse(m.requestID, m.issueError, m.certificates)) case issueCertificatesTimeout: - c.issueCertificatesTimeout(m.requestID) + c.errorReporter.Report(c.issueCertificatesTimeout(m.requestID)) default: - log.Errorf("received unknown message %v, message will be ignored", msg) + c.errorReporter.Report(errors.Errorf("received unknown message %v, message will be ignored", msg)) } case <-c.stopC.Done(): - c.doStop() + c.errorReporter.Report(c.doStop()) return } } @@ -184,63 +192,66 @@ func (c *certManagerImpl) HandleIssueCertificatesResponse(requestID string, issu } // should only be called from the loop goroutine. 
-func (c *certManagerImpl) requestCertificates() { - if requestID, err := c.issueCertificates(); err != nil { +func (c *certManagerImpl) requestCertificates() error { + requestID, err := c.issueCerts(c) + if err != nil { // client side error - log.Errorf("client error sending request to issue certificates for secrets %v: %s", - c.secretNames, err) c.scheduleRetryIssueCertificatesRefresh() - } else { - c.setRequestID(requestID) - c.setCertIssueRequestTimeoutTimer(time.AfterFunc(c.certRequestTimeout, func() { - c.dispatchC <- issueCertificatesTimeout{requestID: requestID} - })) + return errors.Wrapf(err, "client error sending request to issue certificates for secrets %v", + c.secretNames) } + c.setRequestID(requestID) + c.setCertIssueRequestTimeoutTimer(c.jobScheduler.AfterFunc(c.certRequestTimeout, func() { + log.Debugf("request with id %q will timeout in %s", requestID, c.certRequestTimeout) + c.dispatchC <- issueCertificatesTimeout{requestID: requestID} + })) + return nil } // should only be called from the loop goroutine. -func (c *certManagerImpl) doHandleIssueCertificatesResponse(requestID string, issueError error, certificates *storage.TypedServiceCertificateSet) { +func (c *certManagerImpl) handleIssueCertificatesResponse(requestID string, issueError error, + certificates *storage.TypedServiceCertificateSet) error { if requestID != c.requestStatus.requestID { // silently ignore responses sent to the wrong CertManager. log.Debugf("ignoring issue certificate response from unknown request id %q", requestID) - return + return nil } if issueError != nil { // server side error. 
- log.Errorf("server side error issuing certificates for secrets %v: %s", c.secretNames, issueError) c.scheduleRetryIssueCertificatesRefresh() - return + return errors.Wrapf(issueError, "server side error issuing certificates for secrets %v", c.secretNames) } nextTimeToRefresh, refreshErr := c.refreshSecrets(certificates) if refreshErr != nil { - log.Errorf("failure to store the new certificates in the secrets %v: %s", c.secretNames, refreshErr) c.scheduleRetryIssueCertificatesRefresh() - return + return errors.Wrapf(refreshErr, "failure to store the new certificates in the secrets %v", c.secretNames) } log.Infof("successfully refreshed credential in secrets %v", c.secretNames) c.resetBackoff() c.scheduleIssueCertificatesRefresh(nextTimeToRefresh) + return nil } // should only be called from the loop goroutine. -func (c *certManagerImpl) issueCertificatesTimeout(requestID string) { +func (c *certManagerImpl) issueCertificatesTimeout(requestID string) error { if requestID != c.requestStatus.requestID { // this is a timeout for a request we don't care about anymore. log.Debugf("ignoring timeout on issue certificate request from unknown request id %q", requestID) - return + return nil } - log.Errorf("timeout waiting for certificates for secrets %v on request with id %q after waiting for %s", - c.secretNames, requestID, c.certRequestTimeout) c.scheduleRetryIssueCertificatesRefresh() + return errors.Errorf("timeout waiting for certificates for secrets %v on request with id %q after waiting "+ + "for %s", c.secretNames, requestID, c.certRequestTimeout) } // should only be called from the loop goroutine. 
-func (c *certManagerImpl) doStop() { +func (c *certManagerImpl) doStop() error { c.setRequestID("") - log.Info("CertManager stopped.") + log.Infof("cert manager for secrets %v stopped.", c.secretNames) // FIXME + return nil } func (c *certManagerImpl) scheduleRetryIssueCertificatesRefresh() { @@ -251,7 +262,7 @@ func (c *certManagerImpl) scheduleIssueCertificatesRefresh(timeToRefresh time.Du log.Infof("certificates for secrets %v scheduled to be refreshed in %s", c.secretNames, timeToRefresh) // ignore eventual responses for this request. c.setRequestID("") - c.setRefreshTimer(time.AfterFunc(timeToRefresh, func() { + c.setRefreshTimer(c.jobScheduler.AfterFunc(timeToRefresh, func() { c.dispatchC <- requestCertificates{} })) } @@ -294,14 +305,28 @@ func (c *certManagerImpl) refreshSecrets(certificates *storage.TypedServiceCerti // FIXME wrap return 0, err } - // TODO update secrets ROX-8969 + // TODO update secrets ROX-9014 - return c.secretExpiration.GetSecretsDuration(secrets), nil + return c.expirationStrategy.GetSecretsDuration(secrets), nil } type secretsExpirationStrategyImpl struct{} func (s *secretsExpirationStrategyImpl) GetSecretsDuration(secrets map[storage.ServiceType]*v1.Secret) time.Duration { - // TODO ROX-8969 + // TODO ROX-9014 return 5 * time.Second } + +type errorReporterImpl struct{} + +func (*errorReporterImpl) Report(err error) { + if err != nil { + log.Error(err) + } +} + +type jobSchedulerImpl struct{} + +func (*jobSchedulerImpl) AfterFunc(d time.Duration, f func()) *time.Timer { + return time.AfterFunc(d, f) +} diff --git a/sensor/kubernetes/certificates/cert_manager_test.go b/sensor/kubernetes/certificates/cert_manager_test.go index d5fa32543bfaf..bfeda948c57da 100644 --- a/sensor/kubernetes/certificates/cert_manager_test.go +++ b/sensor/kubernetes/certificates/cert_manager_test.go @@ -7,6 +7,8 @@ import ( "github.com/stackrox/rox/generated/storage" "github.com/stackrox/rox/pkg/concurrency" + "github.com/stackrox/rox/pkg/uuid" + 
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -18,7 +20,6 @@ import ( const ( namespace = "namespace" - requestID = "requestID" ) var ( @@ -37,6 +38,81 @@ func TestHandler(t *testing.T) { type certManagerSuite struct { suite.Suite + ctx context.Context + cancelCtx context.CancelFunc + errReporter *recordErrorReporter + scheduler *mockJobScheduler + certManager *certManagerImpl +} + +func (s *certManagerSuite) TearDownTest() { + if s.cancelCtx != nil { + s.cancelCtx() + } + if s.certManager != nil { + s.certManager.Stop() + } + + log.Warn("FIXME") +} + +func (s *certManagerSuite) initialize(testTimeout time.Duration, + secretNamesMap map[storage.ServiceType]string, + certRequestTimeout time.Duration, expirations []time.Duration, + issueCerts CertIssuanceFunc) { + ctx := context.Background() + s.ctx, s.cancelCtx = context.WithTimeout(ctx, testTimeout) + + secretNames := make([]string, len(secretNamesMap)) + for _, secretName := range secretNamesMap { + secretNames = append(secretNames, secretName) + } + secretsClient := fakeSecretsClient(secretNames...) + + s.errReporter = newRecordErrorReporter(3) + s.scheduler = newMockJobScheduler() + + certManager := newCertManager(secretsClient, secretNamesMap, requestBackoff, issueCerts) + certManager.certRequestTimeout = certRequestTimeout + certManager.expirationStrategy = newFixedSecretsExpirationStrategy(expirations...) 
+ certManager.errorReporter = s.errReporter + certManager.jobScheduler = s.scheduler + s.certManager = certManager +} + +func (s *certManagerSuite) TestSuccessfulInitialRefresh() { + secretNames := map[storage.ServiceType]string{ + storage.ServiceType_SCANNER_DB_SERVICE: "foo", + } + certRequestTimeout := 3 * time.Second + expirations := []time.Duration{0, 2 * time.Second} + s.initialize(time.Second, secretNames, certRequestTimeout, expirations, + // FIXME replace by mock method to assert on requestCertificates + func(manager CertManager) (string, error) { + requestID := uuid.NewV4().String() + go func() { + // TODO non nil certs ROX-9014 + s.Require().NoError(manager.HandleIssueCertificatesResponse(requestID, nil, nil)) + }() + + return requestID, nil + }) + + s.scheduler.On("AfterFunc", expirations[0], mock.Anything).Once() + s.scheduler.On("AfterFunc", s.certManager.certRequestTimeout, mock.Anything).Once() + s.scheduler.On("AfterFunc", expirations[1], mock.Anything).Once().Run(func(mock.Arguments) { + s.certManager.Stop() + }) + + s.Require().NoError(s.certManager.Start(s.ctx)) + waitErr, ok := s.errReporter.signal.WaitUntil(s.ctx) + s.Require().True(ok) + s.NoError(waitErr) + + s.scheduler.AssertExpectations(s.T()) + // requestCertificates, handleIssueCertificatesResponse, stop + s.Equal([]error{nil, nil, nil}, s.errReporter.errors) + // TODO: assert timers nil, retry reset, request id nil } func fakeClientSet(secretNames ...string) *fake.Clientset { @@ -54,21 +130,18 @@ func fakeSecretsClient(secretNames ...string) corev1.SecretInterface { type fixedSecretsExpirationStrategy struct { durations []time.Duration invocations int - signal concurrency.ErrorSignal } func newFixedSecretsExpirationStrategy(durations ...time.Duration) *fixedSecretsExpirationStrategy { return &fixedSecretsExpirationStrategy{ durations: durations, - signal: concurrency.NewErrorSignal(), } } -// signals .signal when the last timeout is reached -func (s *fixedSecretsExpirationStrategy) 
GetSecretsDuration(secrets map[storage.ServiceType]*v1.Secret) (duration time.Duration) { +// returns the last duration forever when it runs out of durations +func (s *fixedSecretsExpirationStrategy) GetSecretsDuration(map[storage.ServiceType]*v1.Secret) (duration time.Duration) { s.invocations++ if len(s.durations) <= 1 { - s.signal.Signal() return s.durations[0] } @@ -76,40 +149,46 @@ func (s *fixedSecretsExpirationStrategy) GetSecretsDuration(secrets map[storage. return duration } -func (s *certManagerSuite) TestSuccessfulRefresh() { - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 1*time.Second) - defer cancel() +// the reporter will Signal() its signal as soon as numErrorsToSignal are reported. +type recordErrorReporter struct { + reporter errorReporter + errors []error + numErrorsToSignal int + signal concurrency.ErrorSignal +} - secretName := "foo" - secretNames := map[storage.ServiceType]string{ - storage.ServiceType_SCANNER_DB_SERVICE: secretName, +func (r *recordErrorReporter) Report(err error) { + r.errors = append(r.errors, err) + r.reporter.Report(err) + if len(r.errors) >= r.numErrorsToSignal { + r.signal.Signal() } - secretsClient := fakeSecretsClient(secretName) - certManager := newCertManager(secretsClient, secretNames, requestBackoff, - func(manager CertManager) (requestID string, err error) { - // FIXME nil certs - s.Require().NoError(manager.HandleIssueCertificatesResponse(requestID, nil, nil)) - return requestID, nil - }) - defer certManager.Stop() - certManager.certRequestTimeout = 2 * time.Second - expirationStrategy := newFixedSecretsExpirationStrategy(0, 2*time.Second) - certManager.secretExpiration = expirationStrategy - - s.Require().NoError(certManager.Start(ctx)) +} - // FIXME: idea, add error handler fields to certManagerImpl, that processes errors in loop, and - // make the processing functions return err. 
For prod it just logs; for test it keeps a slice of - // errors or something we can inspect +func newRecordErrorReporter(numErrorsToSignal int) *recordErrorReporter { + return &recordErrorReporter{ + reporter: &errorReporterImpl{}, + signal: concurrency.NewErrorSignal(), + numErrorsToSignal: numErrorsToSignal, + } +} - waitErr, ok := expirationStrategy.signal.WaitUntil(ctx) - s.Require().True(ok) - s.NoError(waitErr) +// AfterFunc records the call in the mock, and then returns AfterFunc() for the +// wrapped scheduler. +type mockJobScheduler struct { + mock.Mock + scheduler jobScheduler +} - // TODO assert certManager not stopped +func (s *mockJobScheduler) AfterFunc(d time.Duration, f func()) *time.Timer { + s.Called(d, f) + return s.scheduler.AfterFunc(d, f) +} - s.Empty(certManager.requestStatus.requestID) +func newMockJobScheduler() *mockJobScheduler { + return &mockJobScheduler{ + scheduler: &jobSchedulerImpl{}, + } } /* @@ -119,6 +198,8 @@ TODO failures: - server failure - client failure - timeout +- unknown request ids +- nil cert manager in all check retries as expected */ From 19c8b3f5ae43f6678113f0e5fc94822ca7eb9701 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Fri, 21 Jan 2022 15:00:39 +0100 Subject: [PATCH 09/34] Move retries with timeout and secret handling out of cert refresher So it is just coordination --- .../kubernetes/certificates/cert_manager.go | 332 ------------------ .../certificates/cert_manager_test.go | 205 ----------- .../kubernetes/certificates/cert_refresher.go | 160 +++++++++ 3 files changed, 160 insertions(+), 537 deletions(-) delete mode 100644 sensor/kubernetes/certificates/cert_manager.go delete mode 100644 sensor/kubernetes/certificates/cert_manager_test.go create mode 100644 sensor/kubernetes/certificates/cert_refresher.go diff --git a/sensor/kubernetes/certificates/cert_manager.go b/sensor/kubernetes/certificates/cert_manager.go deleted file mode 100644 index 359ff3796fdbf..0000000000000 --- 
a/sensor/kubernetes/certificates/cert_manager.go +++ /dev/null @@ -1,332 +0,0 @@ -package certificates - -import ( - "context" - "time" - - "github.com/hashicorp/go-multierror" - "github.com/pkg/errors" - "github.com/stackrox/rox/generated/storage" - "github.com/stackrox/rox/pkg/concurrency" - "github.com/stackrox/rox/pkg/logging" - v1 "k8s.io/api/core/v1" - k8sErrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/util/wait" - corev1 "k8s.io/client-go/kubernetes/typed/core/v1" - "k8s.io/client-go/util/retry" -) - -const ( - // FIXME adjust - internalChannelBuffSize = 50 - defaultCertRequestTimeout = time.Minute -) - -var ( - log = logging.LoggerForModule() - // FIXME adjust - k8sAPIBackoff = retry.DefaultBackoff - _ secretsExpirationStrategy = (*secretsExpirationStrategyImpl)(nil) - - _ CertManager = (*certManagerImpl)(nil) -) - -type secretsExpirationStrategy interface { - GetSecretsDuration(secrets map[storage.ServiceType]*v1.Secret) time.Duration -} - -type errorReporter interface { - Report(err error) -} - -type jobScheduler interface { - AfterFunc(d time.Duration, f func()) *time.Timer -} - -// CertManager is in charge of storing and refreshing service TLS certificates in a set of k8s secrets. -type CertManager interface { - Start(ctx context.Context) error - Stop() - // HandleIssueCertificatesResponse handles a certificate issue response. - // - Precondition: if issueError is nil then certificates is not nil. - // - Implementations should handle a nil receiver like an unknown request ID. - HandleIssueCertificatesResponse(requestID string, issueError error, certificates *storage.TypedServiceCertificateSet) error -} - -// CertIssuanceFunc can be used to request a certificate. -type CertIssuanceFunc func(CertManager) (requestID string, err error) -type certManagerImpl struct { - // should be kept constant. 
- secretNames map[storage.ServiceType]string - secretsClient corev1.SecretInterface - issueCerts CertIssuanceFunc - stopC concurrency.ErrorSignal - certRequestTimeout time.Duration - certRequestBackoffProto wait.Backoff - expirationStrategy secretsExpirationStrategy - errorReporter errorReporter - jobScheduler jobScheduler - // set at Start(). - ctx context.Context - // handled by loop goroutine. - dispatchC chan interface{} - requestStatus *requestStatus - refreshTimer *time.Timer - certIssueRequestTimeoutTimer *time.Timer -} - -type requestStatus struct { - requestID string - backoff wait.Backoff -} - -// NewCertManager creates a new CertManager. -func NewCertManager(secretsClient corev1.SecretInterface, secretNames map[storage.ServiceType]string, - certRequestBackoff wait.Backoff, issueCerts CertIssuanceFunc) CertManager { - return newCertManager(secretsClient, secretNames, certRequestBackoff, issueCerts) -} - -func newCertManager(secretsClient corev1.SecretInterface, secretNames map[storage.ServiceType]string, - certRequestBackoff wait.Backoff, issueCerts CertIssuanceFunc) *certManagerImpl { - return &certManagerImpl{ - secretNames: secretNames, - secretsClient: secretsClient, - issueCerts: issueCerts, - stopC: concurrency.NewErrorSignal(), - certRequestTimeout: defaultCertRequestTimeout, - certRequestBackoffProto: certRequestBackoff, - expirationStrategy: &secretsExpirationStrategyImpl{}, - errorReporter: &errorReporterImpl{}, - jobScheduler: &jobSchedulerImpl{}, - dispatchC: make(chan interface{}, internalChannelBuffSize), - requestStatus: &requestStatus{}, - } -} - -func (c *certManagerImpl) Start(ctx context.Context) error { - c.ctx = ctx - secrets, err := c.fetchSecrets() - if err != nil { - return errors.Wrapf(err, "fetching secrets %v", c.secretNames) - } - // this refreshes immediately if certificates are already expired. 
- c.scheduleIssueCertificatesRefresh(c.expirationStrategy.GetSecretsDuration(secrets)) - - go c.loop() - - return nil -} - -func (c *certManagerImpl) Stop() { - c.stopC.Signal() -} - -func (c *certManagerImpl) loop() { - // FIXME: protect private methods and fields - for { - select { - case msg := <-c.dispatchC: - switch m := msg.(type) { - case requestCertificates: - c.errorReporter.Report(c.requestCertificates()) - case handleIssueCertificatesResponse: - c.errorReporter.Report(c.handleIssueCertificatesResponse(m.requestID, m.issueError, m.certificates)) - case issueCertificatesTimeout: - c.errorReporter.Report(c.issueCertificatesTimeout(m.requestID)) - default: - c.errorReporter.Report(errors.Errorf("received unknown message %v, message will be ignored", msg)) - } - - case <-c.stopC.Done(): - c.errorReporter.Report(c.doStop()) - return - } - } -} - -type handleIssueCertificatesResponse struct { - requestID string - issueError error - certificates *storage.TypedServiceCertificateSet -} - -type requestCertificates struct{} - -type issueCertificatesTimeout struct { - requestID string -} - -func (c *certManagerImpl) setRefreshTimer(timer *time.Timer) { - if c.refreshTimer != nil { - c.refreshTimer.Stop() - } - c.refreshTimer = timer -} - -func (c *certManagerImpl) setCertIssueRequestTimeoutTimer(timer *time.Timer) { - if c.certIssueRequestTimeoutTimer != nil { - c.certIssueRequestTimeoutTimer.Stop() - } - c.certIssueRequestTimeoutTimer = timer -} - -// set request id, and stops timers. -func (c *certManagerImpl) setRequestID(requestID string) { - c.requestStatus.requestID = requestID - c.setRefreshTimer(nil) - c.setCertIssueRequestTimeoutTimer(nil) -} - -// reset retry backoff. 
-func (c *certManagerImpl) resetBackoff() { - c.requestStatus.backoff = c.certRequestBackoffProto -} - -func (c *certManagerImpl) HandleIssueCertificatesResponse(requestID string, issueError error, certificates *storage.TypedServiceCertificateSet) error { - if c == nil { - return errors.Errorf("unknown request ID %s, potentially due to request timeout", requestID) - } - c.dispatchC <- handleIssueCertificatesResponse{requestID: requestID, issueError: issueError, certificates: certificates} - return nil -} - -// should only be called from the loop goroutine. -func (c *certManagerImpl) requestCertificates() error { - requestID, err := c.issueCerts(c) - if err != nil { - // client side error - c.scheduleRetryIssueCertificatesRefresh() - return errors.Wrapf(err, "client error sending request to issue certificates for secrets %v", - c.secretNames) - } - c.setRequestID(requestID) - c.setCertIssueRequestTimeoutTimer(c.jobScheduler.AfterFunc(c.certRequestTimeout, func() { - log.Debugf("request with id %q will timeout in %s", requestID, c.certRequestTimeout) - c.dispatchC <- issueCertificatesTimeout{requestID: requestID} - })) - return nil -} - -// should only be called from the loop goroutine. -func (c *certManagerImpl) handleIssueCertificatesResponse(requestID string, issueError error, - certificates *storage.TypedServiceCertificateSet) error { - if requestID != c.requestStatus.requestID { - // silently ignore responses sent to the wrong CertManager. - log.Debugf("ignoring issue certificate response from unknown request id %q", requestID) - return nil - } - - if issueError != nil { - // server side error. 
- c.scheduleRetryIssueCertificatesRefresh() - return errors.Wrapf(issueError, "server side error issuing certificates for secrets %v", c.secretNames) - } - - nextTimeToRefresh, refreshErr := c.refreshSecrets(certificates) - if refreshErr != nil { - c.scheduleRetryIssueCertificatesRefresh() - return errors.Wrapf(refreshErr, "failure to store the new certificates in the secrets %v", c.secretNames) - } - - log.Infof("successfully refreshed credential in secrets %v", c.secretNames) - c.resetBackoff() - c.scheduleIssueCertificatesRefresh(nextTimeToRefresh) - return nil -} - -// should only be called from the loop goroutine. -func (c *certManagerImpl) issueCertificatesTimeout(requestID string) error { - if requestID != c.requestStatus.requestID { - // this is a timeout for a request we don't care about anymore. - log.Debugf("ignoring timeout on issue certificate request from unknown request id %q", requestID) - return nil - } - c.scheduleRetryIssueCertificatesRefresh() - return errors.Errorf("timeout waiting for certificates for secrets %v on request with id %q after waiting "+ - "for %s", c.secretNames, requestID, c.certRequestTimeout) -} - -// should only be called from the loop goroutine. -func (c *certManagerImpl) doStop() error { - c.setRequestID("") - log.Infof("cert manager for secrets %v stopped.", c.secretNames) // FIXME - return nil -} - -func (c *certManagerImpl) scheduleRetryIssueCertificatesRefresh() { - c.scheduleIssueCertificatesRefresh(c.requestStatus.backoff.Step()) -} - -func (c *certManagerImpl) scheduleIssueCertificatesRefresh(timeToRefresh time.Duration) { - log.Infof("certificates for secrets %v scheduled to be refreshed in %s", c.secretNames, timeToRefresh) - // ignore eventual responses for this request. 
- c.setRequestID("") - c.setRefreshTimer(c.jobScheduler.AfterFunc(timeToRefresh, func() { - c.dispatchC <- requestCertificates{} - })) -} - -func (c *certManagerImpl) fetchSecrets() (map[storage.ServiceType]*v1.Secret, error) { - secretsMap := make(map[storage.ServiceType]*v1.Secret, len(c.secretNames)) - var fetchErr error - for serviceType, secretName := range c.secretNames { - var ( - secret *v1.Secret - err error - ) - retryErr := retry.OnError(k8sAPIBackoff, - func(err error) bool { - return !k8sErrors.IsNotFound(err) - }, - func() error { - secret, err = c.secretsClient.Get(c.ctx, secretName, metav1.GetOptions{}) - return err - }, - ) - if retryErr != nil { - fetchErr = multierror.Append(fetchErr, errors.Wrapf(retryErr, "for secret %s", secretName)) - } else { - secretsMap[serviceType] = secret - } - } - - if fetchErr != nil { - return nil, fetchErr - } - return secretsMap, nil -} - -// Performs retries for reads and writes with the k8s API. -// On success, it returns the duration after which the secrets should be refreshed. 
-func (c *certManagerImpl) refreshSecrets(certificates *storage.TypedServiceCertificateSet) (time.Duration, error) { - secrets, err := c.fetchSecrets() - if err != nil { - // FIXME wrap - return 0, err - } - // TODO update secrets ROX-9014 - - return c.expirationStrategy.GetSecretsDuration(secrets), nil -} - -type secretsExpirationStrategyImpl struct{} - -func (s *secretsExpirationStrategyImpl) GetSecretsDuration(secrets map[storage.ServiceType]*v1.Secret) time.Duration { - // TODO ROX-9014 - return 5 * time.Second -} - -type errorReporterImpl struct{} - -func (*errorReporterImpl) Report(err error) { - if err != nil { - log.Error(err) - } -} - -type jobSchedulerImpl struct{} - -func (*jobSchedulerImpl) AfterFunc(d time.Duration, f func()) *time.Timer { - return time.AfterFunc(d, f) -} diff --git a/sensor/kubernetes/certificates/cert_manager_test.go b/sensor/kubernetes/certificates/cert_manager_test.go deleted file mode 100644 index bfeda948c57da..0000000000000 --- a/sensor/kubernetes/certificates/cert_manager_test.go +++ /dev/null @@ -1,205 +0,0 @@ -package certificates - -import ( - "context" - "testing" - "time" - - "github.com/stackrox/rox/generated/storage" - "github.com/stackrox/rox/pkg/concurrency" - "github.com/stackrox/rox/pkg/uuid" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/util/wait" - "k8s.io/client-go/kubernetes/fake" - corev1 "k8s.io/client-go/kubernetes/typed/core/v1" -) - -const ( - namespace = "namespace" -) - -var ( - requestBackoff = wait.Backoff{ - Steps: 3, - Duration: 10 * time.Millisecond, - Factor: 10.0, - Jitter: 0.1, - Cap: 2 * time.Second, - } -) - -func TestHandler(t *testing.T) { - suite.Run(t, new(certManagerSuite)) -} - -type certManagerSuite struct { - suite.Suite - ctx context.Context - cancelCtx context.CancelFunc - errReporter *recordErrorReporter - scheduler 
*mockJobScheduler - certManager *certManagerImpl -} - -func (s *certManagerSuite) TearDownTest() { - if s.cancelCtx != nil { - s.cancelCtx() - } - if s.certManager != nil { - s.certManager.Stop() - } - - log.Warn("FIXME") -} - -func (s *certManagerSuite) initialize(testTimeout time.Duration, - secretNamesMap map[storage.ServiceType]string, - certRequestTimeout time.Duration, expirations []time.Duration, - issueCerts CertIssuanceFunc) { - ctx := context.Background() - s.ctx, s.cancelCtx = context.WithTimeout(ctx, testTimeout) - - secretNames := make([]string, len(secretNamesMap)) - for _, secretName := range secretNamesMap { - secretNames = append(secretNames, secretName) - } - secretsClient := fakeSecretsClient(secretNames...) - - s.errReporter = newRecordErrorReporter(3) - s.scheduler = newMockJobScheduler() - - certManager := newCertManager(secretsClient, secretNamesMap, requestBackoff, issueCerts) - certManager.certRequestTimeout = certRequestTimeout - certManager.expirationStrategy = newFixedSecretsExpirationStrategy(expirations...) 
- certManager.errorReporter = s.errReporter - certManager.jobScheduler = s.scheduler - s.certManager = certManager -} - -func (s *certManagerSuite) TestSuccessfulInitialRefresh() { - secretNames := map[storage.ServiceType]string{ - storage.ServiceType_SCANNER_DB_SERVICE: "foo", - } - certRequestTimeout := 3 * time.Second - expirations := []time.Duration{0, 2 * time.Second} - s.initialize(time.Second, secretNames, certRequestTimeout, expirations, - // FIXME replace by mock method to assert on requestCertificates - func(manager CertManager) (string, error) { - requestID := uuid.NewV4().String() - go func() { - // TODO non nil certs ROX-9014 - s.Require().NoError(manager.HandleIssueCertificatesResponse(requestID, nil, nil)) - }() - - return requestID, nil - }) - - s.scheduler.On("AfterFunc", expirations[0], mock.Anything).Once() - s.scheduler.On("AfterFunc", s.certManager.certRequestTimeout, mock.Anything).Once() - s.scheduler.On("AfterFunc", expirations[1], mock.Anything).Once().Run(func(mock.Arguments) { - s.certManager.Stop() - }) - - s.Require().NoError(s.certManager.Start(s.ctx)) - waitErr, ok := s.errReporter.signal.WaitUntil(s.ctx) - s.Require().True(ok) - s.NoError(waitErr) - - s.scheduler.AssertExpectations(s.T()) - // requestCertificates, handleIssueCertificatesResponse, stop - s.Equal([]error{nil, nil, nil}, s.errReporter.errors) - // TODO: assert timers nil, retry reset, request id nil -} - -func fakeClientSet(secretNames ...string) *fake.Clientset { - secrets := make([]runtime.Object, len(secretNames)) - for i, secretName := range secretNames { - secrets[i] = &v1.Secret{ObjectMeta: metav1.ObjectMeta{Name: secretName, Namespace: namespace}} - } - return fake.NewSimpleClientset(secrets...) 
-} - -func fakeSecretsClient(secretNames ...string) corev1.SecretInterface { - return fakeClientSet(secretNames...).CoreV1().Secrets(namespace) -} - -type fixedSecretsExpirationStrategy struct { - durations []time.Duration - invocations int -} - -func newFixedSecretsExpirationStrategy(durations ...time.Duration) *fixedSecretsExpirationStrategy { - return &fixedSecretsExpirationStrategy{ - durations: durations, - } -} - -// returns the last duration forever when it runs out of durations -func (s *fixedSecretsExpirationStrategy) GetSecretsDuration(map[storage.ServiceType]*v1.Secret) (duration time.Duration) { - s.invocations++ - if len(s.durations) <= 1 { - return s.durations[0] - } - - duration, s.durations = s.durations[0], s.durations[1:] - return duration -} - -// the reporter will Signal() its signal as soon as numErrorsToSignal are reported. -type recordErrorReporter struct { - reporter errorReporter - errors []error - numErrorsToSignal int - signal concurrency.ErrorSignal -} - -func (r *recordErrorReporter) Report(err error) { - r.errors = append(r.errors, err) - r.reporter.Report(err) - if len(r.errors) >= r.numErrorsToSignal { - r.signal.Signal() - } -} - -func newRecordErrorReporter(numErrorsToSignal int) *recordErrorReporter { - return &recordErrorReporter{ - reporter: &errorReporterImpl{}, - signal: concurrency.NewErrorSignal(), - numErrorsToSignal: numErrorsToSignal, - } -} - -// AfterFunc records the call in the mock, and then returns AfterFunc() for the -// wrapped scheduler. 
-type mockJobScheduler struct { - mock.Mock - scheduler jobScheduler -} - -func (s *mockJobScheduler) AfterFunc(d time.Duration, f func()) *time.Timer { - s.Called(d, f) - return s.scheduler.AfterFunc(d, f) -} - -func newMockJobScheduler() *mockJobScheduler { - return &mockJobScheduler{ - scheduler: &jobSchedulerImpl{}, - } -} - -/* -TODO failures: - -- success -- server failure -- client failure -- timeout -- unknown request ids -- nil cert manager - -in all check retries as expected -*/ diff --git a/sensor/kubernetes/certificates/cert_refresher.go b/sensor/kubernetes/certificates/cert_refresher.go new file mode 100644 index 0000000000000..e866cf6be319b --- /dev/null +++ b/sensor/kubernetes/certificates/cert_refresher.go @@ -0,0 +1,160 @@ +package certificates + +import ( + "context" + "time" + + "github.com/pkg/errors" + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/logging" + "github.com/stackrox/rox/pkg/retry" + + "k8s.io/apimachinery/pkg/util/wait" +) + +const ( + // FIXME adjust + defaultCertRequestTimeout = time.Minute + refreshCrashWaitTime = time.Minute +) + +var ( + log = logging.LoggerForModule() + _ CertRefresher = (*certRefresherImpl)(nil) +) + + +// CertRefresher is in charge of scheduling the refresh of the TLS certificates of a set of services. +type CertRefresher interface { + Start(ctx context.Context) error + Stop() +} +type certRefresherImpl struct { + conf certRefresherConf + ctx context.Context + refreshTimer *time.Timer +} + +type certRefresherConf struct { + certsDescription string + certificateSource CertificateSource + issueCertificates func(context.Context) (*storage.TypedServiceCertificateSet, error) +} + +type CertificateSource interface { + // RetryableSource to fetch certificates of type *storage.TypedServiceCertificateSet. + retry.RetryableSource + // HandleCertificates stores the certificates in some permanent storage, and returns the time until the next + // refresh. 
+ // If certificates are nil then this should initialize or retrieve the certificates from local storage, + // and compute their next refresh time. + HandleCertificates(certificates *storage.TypedServiceCertificateSet) (timeToRefresh time.Duration, err error) +} + +func NewCertRefresher(certsDescription string, certsSource CertificateSource, + certRequestBackoff wait.Backoff) CertRefresher { + return newCertRefresher(certsDescription, certsSource, certRequestBackoff) +} + +func newCertRefresher(certsDescription string, certsSource CertificateSource, + certRequestBackoff wait.Backoff, ) *certRefresherImpl { + return &certRefresherImpl{ + conf: certRefresherConf{ + certsDescription: certsDescription, + certificateSource: certsSource, + issueCertificates: createIssueCertificates(certsDescription, certsSource, certRequestBackoff), + }, + } +} + +// the returned function only fails if it is cancelled with its input context. +func createIssueCertificates(certsDescription string, certsSource retry.RetryableSource, + backoff wait.Backoff) func(context.Context) (*storage.TypedServiceCertificateSet, error) { + retriever := retry.NewRetryableSourceRetriever(backoff, defaultCertRequestTimeout) + retriever.OnError = func(err error, timeToNextRetry time.Duration) { + log.Errorf("error retrieving certificates %s, will retry in %s: %s", + certsDescription, timeToNextRetry, err) + } + retriever.ValidateResult = func(maybeCerts interface{}) bool { + _, ok := maybeCerts.(*storage.TypedServiceCertificateSet) + return ok + } + return func(ctx context.Context) (*storage.TypedServiceCertificateSet, error) { + retriever.Backoff = backoff // reset backoff for each retrieval. 
+ maybeCerts, err := retriever.Run(ctx, certsSource) + if err != nil { + return nil, err + } + certs, ok := maybeCerts.(*storage.TypedServiceCertificateSet) + if !ok { + // this shouldn't happen due to validation + return nil, errors.Errorf("critical error: response %v has unexpected type", maybeCerts) + } + return certs, nil + } +} + +func (c *certRefresherImpl) Start(ctx context.Context) error { + c.ctx = ctx + err := c.initialRefresh() + if err != nil { + return err + } + go func() { + <- c.ctx.Done() + c.Stop() + }() + return nil +} + +func (c *certRefresherImpl) Stop() { + c.setRefreshTimer(nil) + log.Infof("stopped for certificates %s", c.conf.certsDescription) +} + +func (c *certRefresherImpl) initialRefresh() error { + timeToRefresh, err := c.conf.certificateSource.HandleCertificates(nil) + if err != nil { + return errors.Wrapf(err, "critical error processing certificates %s, aborting", c.conf.certsDescription) + } + c.scheduleIssueCertificatesRefresh(timeToRefresh) + return nil +} + +func (c *certRefresherImpl) scheduleIssueCertificatesRefresh(timeToRefresh time.Duration) { + c.setRefreshTimer(time.AfterFunc(timeToRefresh, func() { + certificates, issueErr := c.conf.issueCertificates(c.ctx) + if issueErr != nil { + log.Errorf("critical error issuing certificates %s: %s", + c.conf.certsDescription, issueErr) + c.recoverFromRefreshCrash() + } + nextTimeToRefresh, handleErr := c.conf.certificateSource.HandleCertificates(certificates) + if handleErr != nil { + log.Errorf("critical error processing certificates %s: %s", + c.conf.certsDescription, issueErr) + c.recoverFromRefreshCrash() + } + + log.Infof("successfully refreshed credentials for certificates %v", c.conf.certsDescription) + c.scheduleIssueCertificatesRefresh(nextTimeToRefresh) + })) + log.Infof("credentials for %v scheduled to be refreshed in %s", + c.conf.certsDescription, timeToRefresh) +} + +func (c *certRefresherImpl) setRefreshTimer(timer *time.Timer) { + if c.refreshTimer != nil { + 
c.refreshTimer.Stop() + } + c.refreshTimer = timer +} + +func (c *certRefresherImpl) recoverFromRefreshCrash() { + // TODO: consider backoff here. + c.setRefreshTimer(time.AfterFunc(refreshCrashWaitTime, func() { + c.initialRefresh() + })) + log.Errorf("refresh process for %s will restart in %s", c.conf.certsDescription, refreshCrashWaitTime) +} + From 6bad26df61ab6553a84668c8b22b6fa20b07498a Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Fri, 21 Jan 2022 15:06:28 +0100 Subject: [PATCH 10/34] fix style From 8a9e93505e06d2a7047d60468f61683ee70b8824 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Fri, 21 Jan 2022 15:06:34 +0100 Subject: [PATCH 11/34] retry failed recovery also fix style --- .../kubernetes/certificates/cert_refresher.go | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/sensor/kubernetes/certificates/cert_refresher.go b/sensor/kubernetes/certificates/cert_refresher.go index e866cf6be319b..41bcf40963634 100644 --- a/sensor/kubernetes/certificates/cert_refresher.go +++ b/sensor/kubernetes/certificates/cert_refresher.go @@ -8,39 +8,39 @@ import ( "github.com/stackrox/rox/generated/storage" "github.com/stackrox/rox/pkg/logging" "github.com/stackrox/rox/pkg/retry" - "k8s.io/apimachinery/pkg/util/wait" ) const ( // FIXME adjust defaultCertRequestTimeout = time.Minute - refreshCrashWaitTime = time.Minute + refreshCrashWaitTime = time.Minute ) var ( - log = logging.LoggerForModule() - _ CertRefresher = (*certRefresherImpl)(nil) + log = logging.LoggerForModule() + _ CertRefresher = (*certRefresherImpl)(nil) ) - // CertRefresher is in charge of scheduling the refresh of the TLS certificates of a set of services. 
type CertRefresher interface { Start(ctx context.Context) error Stop() } type certRefresherImpl struct { - conf certRefresherConf - ctx context.Context + conf certRefresherConf + ctx context.Context refreshTimer *time.Timer } type certRefresherConf struct { - certsDescription string + certsDescription string certificateSource CertificateSource - issueCertificates func(context.Context) (*storage.TypedServiceCertificateSet, error) + issueCertificates func(context.Context) (*storage.TypedServiceCertificateSet, error) } +// CertificateSource is able to fetch certificates of type *storage.TypedServiceCertificateSet, and +// to process the retrieved certificates. type CertificateSource interface { // RetryableSource to fetch certificates of type *storage.TypedServiceCertificateSet. retry.RetryableSource @@ -51,16 +51,17 @@ type CertificateSource interface { HandleCertificates(certificates *storage.TypedServiceCertificateSet) (timeToRefresh time.Duration, err error) } +// NewCertRefresher creates a new CertRefresher. 
func NewCertRefresher(certsDescription string, certsSource CertificateSource, certRequestBackoff wait.Backoff) CertRefresher { return newCertRefresher(certsDescription, certsSource, certRequestBackoff) } func newCertRefresher(certsDescription string, certsSource CertificateSource, - certRequestBackoff wait.Backoff, ) *certRefresherImpl { + certRequestBackoff wait.Backoff) *certRefresherImpl { return &certRefresherImpl{ conf: certRefresherConf{ - certsDescription: certsDescription, + certsDescription: certsDescription, certificateSource: certsSource, issueCertificates: createIssueCertificates(certsDescription, certsSource, certRequestBackoff), }, @@ -76,7 +77,7 @@ func createIssueCertificates(certsDescription string, certsSource retry.Retryabl certsDescription, timeToNextRetry, err) } retriever.ValidateResult = func(maybeCerts interface{}) bool { - _, ok := maybeCerts.(*storage.TypedServiceCertificateSet) + _, ok := maybeCerts.(*storage.TypedServiceCertificateSet) return ok } return func(ctx context.Context) (*storage.TypedServiceCertificateSet, error) { @@ -100,8 +101,8 @@ func (c *certRefresherImpl) Start(ctx context.Context) error { if err != nil { return err } - go func() { - <- c.ctx.Done() + go func() { + <-c.ctx.Done() c.Stop() }() return nil @@ -115,7 +116,8 @@ func (c *certRefresherImpl) Stop() { func (c *certRefresherImpl) initialRefresh() error { timeToRefresh, err := c.conf.certificateSource.HandleCertificates(nil) if err != nil { - return errors.Wrapf(err, "critical error processing certificates %s, aborting", c.conf.certsDescription) + return errors.Wrapf(err, "critical error processing stored certificates %s, aborting", + c.conf.certsDescription) } c.scheduleIssueCertificatesRefresh(timeToRefresh) return nil @@ -153,8 +155,11 @@ func (c *certRefresherImpl) setRefreshTimer(timer *time.Timer) { func (c *certRefresherImpl) recoverFromRefreshCrash() { // TODO: consider backoff here. 
c.setRefreshTimer(time.AfterFunc(refreshCrashWaitTime, func() { - c.initialRefresh() + err := c.initialRefresh() + if err != nil { + log.Error(err) + c.recoverFromRefreshCrash() + } })) log.Errorf("refresh process for %s will restart in %s", c.conf.certsDescription, refreshCrashWaitTime) } - From ad4239cfe92a79556531a5c66f1328af18686500 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Tue, 25 Jan 2022 17:58:06 +0100 Subject: [PATCH 12/34] generalize CertRefresher into generic RetryTicker --- pkg/concurrency/retry_ticker.go | 72 ++++++++ pkg/concurrency/retry_ticker_test.go | 111 ++++++++++++ .../kubernetes/certificates/cert_refresher.go | 165 ------------------ 3 files changed, 183 insertions(+), 165 deletions(-) create mode 100644 pkg/concurrency/retry_ticker.go create mode 100644 pkg/concurrency/retry_ticker_test.go delete mode 100644 sensor/kubernetes/certificates/cert_refresher.go diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go new file mode 100644 index 0000000000000..afb52c89dec88 --- /dev/null +++ b/pkg/concurrency/retry_ticker.go @@ -0,0 +1,72 @@ +package concurrency + +import ( + "context" + "time" + + "k8s.io/apimachinery/pkg/util/wait" +) + +// RetryTicker repeatedly calls a function with a timeout and a retry backoff strategy. +type RetryTicker struct { + f func(ctx context.Context) (timeToNextTick time.Duration, err error) + tickTimeout time.Duration + backoffPrototype wait.Backoff + backoff wait.Backoff + tickTimer *time.Timer + OnTickSuccess func(nextTimeToTick time.Duration) + OnTickError func(error) +} + +// NewRetryTicker returns a new RetryTicker that calls the function f repeatedly: +// - When started, the RetryTicker calls f immediately, and if it returns an error +// then the RetryTicker will wait the time returned by backoff.Step before calling f again. +// - f must return an error if ctx is cancelled. RetryTicker always call f with a context with a timeout of tickTimeout. 
+// - On success RetryTicker will reset backoff, and wait the amount of time returned by f before running f again. +func NewRetryTicker(f func(ctx context.Context) (nextTimeToTick time.Duration, err error), + tickTimeout time.Duration, + backoff wait.Backoff) *RetryTicker{ + ticker := &RetryTicker{ + f: f, + tickTimeout: tickTimeout, + backoffPrototype: backoff, + backoff: backoff, + } + return ticker +} + +func (t *RetryTicker) Start() { + t.scheduleTick(0) +} + +func (t *RetryTicker) Stop() { + t.setTickTimer(nil) +} + +func (t *RetryTicker) scheduleTick(timeToTick time.Duration) { + t.setTickTimer(time.AfterFunc(timeToTick, func() { + ctx, cancel := context.WithTimeout(context.Background(), t.tickTimeout) + defer cancel() + + nextTimeToTick, tickErr := t.f(ctx) + if tickErr != nil { + if t.OnTickError != nil { + t.OnTickError(tickErr) + } + t.scheduleTick(t.backoff.Step()) + return + } + if t.OnTickSuccess != nil { + t.OnTickSuccess(nextTimeToTick) + } + t.backoff = t.backoffPrototype // reset backoff strategy + t.scheduleTick(nextTimeToTick) + })) +} + +func (t *RetryTicker) setTickTimer(timer *time.Timer) { + if t.tickTimer != nil { + t.tickTimer.Stop() + } + t.tickTimer = timer +} diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go new file mode 100644 index 0000000000000..d707ceb551317 --- /dev/null +++ b/pkg/concurrency/retry_ticker_test.go @@ -0,0 +1,111 @@ +package concurrency + +import ( + "context" + "errors" + "testing" + "time" + + "k8s.io/apimachinery/pkg/util/wait" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" +) + +var ( + pollingInterval = 10 * time.Millisecond + epsilonTime = 100 * time.Millisecond + longTime = 2 * time.Second + backoff = wait.Backoff{ + Duration: epsilonTime, + Factor: 1, + Jitter: 0, + Steps: 1, + Cap: epsilonTime, + } +) +func TestHandler(t *testing.T) { + suite.Run(t, new(retryTickerSuite)) +} + +type retryTickerSuite struct { + suite.Suite +} + +type 
testTickFun struct { + mock.Mock +} + +func (f *testTickFun) f(ctx context.Context) (nextTimeToTick time.Duration, err error) { + args := f.Called(ctx) + return args.Get(0).(time.Duration), args.Error(1) +} + +func (f *testTickFun) OnTickSuccess(nextTimeToTick time.Duration) { + f.Called(nextTimeToTick) +} + +func (f *testTickFun) OnTickError(err error) { + f.Called(err) +} + +func (s *retryTickerSuite) TestRetryTicker() { + testCases := map[string]struct { + // timeToNextTick1 time.Duration + // tickErr1 error + // timeToNextTick2 time.Duration + // tickErr2 error + forceError bool + addEventHandlers bool + }{ + "successWithEventHandlers": { forceError: false, addEventHandlers: true}, + "successWithoutEventHandlers": { forceError: false, addEventHandlers: false}, + "oneErrorWithEventHandlers": { forceError: true, addEventHandlers: true}, + "oneErrorWithoutEventHandlers": { forceError: true, addEventHandlers: false}, + } + for tcName, tc := range testCases { + s.Run(tcName, func() { + var done1, done2 Flag + wait1 := 2 * epsilonTime + forcedErr := errors.New("forced") + + m := &testTickFun{} + ticker := NewRetryTicker(m.f, longTime, backoff) + + if !tc.forceError { + m.On("f", mock.Anything).Return(wait1, nil).Run(func(args mock.Arguments) { + done1.Set(true) + }).Once() + } else { + m.On("f", mock.Anything).Return(time.Duration(0), forcedErr).Run(func(args mock.Arguments) { + done1.Set(true) + }).Once() + } + m.On("f", mock.Anything).Return(longTime, nil).Run(func(args mock.Arguments) { + done2.Set(true) + }).Once() + if tc.addEventHandlers { + ticker.OnTickSuccess = m.OnTickSuccess + ticker.OnTickError = m.OnTickError + if !tc.forceError { + m.On("OnTickSuccess", wait1).Once() + } else { + m.On("OnTickError", forcedErr).Once() + } + m.On("OnTickSuccess", longTime).Once() + } + + ticker.Start() + defer ticker.Stop() + + s.True(PollWithTimeout(done1.Get, pollingInterval, epsilonTime)) + if !tc.forceError { + s.True(PollWithTimeout(done2.Get, pollingInterval, wait1 
+ epsilonTime)) + } else { + s.True(PollWithTimeout(done2.Get, pollingInterval, backoff.Cap + epsilonTime)) + } + + m.AssertExpectations(s.T()) + }) + } +} diff --git a/sensor/kubernetes/certificates/cert_refresher.go b/sensor/kubernetes/certificates/cert_refresher.go deleted file mode 100644 index 41bcf40963634..0000000000000 --- a/sensor/kubernetes/certificates/cert_refresher.go +++ /dev/null @@ -1,165 +0,0 @@ -package certificates - -import ( - "context" - "time" - - "github.com/pkg/errors" - "github.com/stackrox/rox/generated/storage" - "github.com/stackrox/rox/pkg/logging" - "github.com/stackrox/rox/pkg/retry" - "k8s.io/apimachinery/pkg/util/wait" -) - -const ( - // FIXME adjust - defaultCertRequestTimeout = time.Minute - refreshCrashWaitTime = time.Minute -) - -var ( - log = logging.LoggerForModule() - _ CertRefresher = (*certRefresherImpl)(nil) -) - -// CertRefresher is in charge of scheduling the refresh of the TLS certificates of a set of services. -type CertRefresher interface { - Start(ctx context.Context) error - Stop() -} -type certRefresherImpl struct { - conf certRefresherConf - ctx context.Context - refreshTimer *time.Timer -} - -type certRefresherConf struct { - certsDescription string - certificateSource CertificateSource - issueCertificates func(context.Context) (*storage.TypedServiceCertificateSet, error) -} - -// CertificateSource is able to fetch certificates of type *storage.TypedServiceCertificateSet, and -// to process the retrieved certificates. -type CertificateSource interface { - // RetryableSource to fetch certificates of type *storage.TypedServiceCertificateSet. - retry.RetryableSource - // HandleCertificates stores the certificates in some permanent storage, and returns the time until the next - // refresh. - // If certificates are nil then this should initialize or retrieve the certificates from local storage, - // and compute their next refresh time. 
- HandleCertificates(certificates *storage.TypedServiceCertificateSet) (timeToRefresh time.Duration, err error) -} - -// NewCertRefresher creates a new CertRefresher. -func NewCertRefresher(certsDescription string, certsSource CertificateSource, - certRequestBackoff wait.Backoff) CertRefresher { - return newCertRefresher(certsDescription, certsSource, certRequestBackoff) -} - -func newCertRefresher(certsDescription string, certsSource CertificateSource, - certRequestBackoff wait.Backoff) *certRefresherImpl { - return &certRefresherImpl{ - conf: certRefresherConf{ - certsDescription: certsDescription, - certificateSource: certsSource, - issueCertificates: createIssueCertificates(certsDescription, certsSource, certRequestBackoff), - }, - } -} - -// the returned function only fails if it is cancelled with its input context. -func createIssueCertificates(certsDescription string, certsSource retry.RetryableSource, - backoff wait.Backoff) func(context.Context) (*storage.TypedServiceCertificateSet, error) { - retriever := retry.NewRetryableSourceRetriever(backoff, defaultCertRequestTimeout) - retriever.OnError = func(err error, timeToNextRetry time.Duration) { - log.Errorf("error retrieving certificates %s, will retry in %s: %s", - certsDescription, timeToNextRetry, err) - } - retriever.ValidateResult = func(maybeCerts interface{}) bool { - _, ok := maybeCerts.(*storage.TypedServiceCertificateSet) - return ok - } - return func(ctx context.Context) (*storage.TypedServiceCertificateSet, error) { - retriever.Backoff = backoff // reset backoff for each retrieval. 
- maybeCerts, err := retriever.Run(ctx, certsSource) - if err != nil { - return nil, err - } - certs, ok := maybeCerts.(*storage.TypedServiceCertificateSet) - if !ok { - // this shouldn't happen due to validation - return nil, errors.Errorf("critical error: response %v has unexpected type", maybeCerts) - } - return certs, nil - } -} - -func (c *certRefresherImpl) Start(ctx context.Context) error { - c.ctx = ctx - err := c.initialRefresh() - if err != nil { - return err - } - go func() { - <-c.ctx.Done() - c.Stop() - }() - return nil -} - -func (c *certRefresherImpl) Stop() { - c.setRefreshTimer(nil) - log.Infof("stopped for certificates %s", c.conf.certsDescription) -} - -func (c *certRefresherImpl) initialRefresh() error { - timeToRefresh, err := c.conf.certificateSource.HandleCertificates(nil) - if err != nil { - return errors.Wrapf(err, "critical error processing stored certificates %s, aborting", - c.conf.certsDescription) - } - c.scheduleIssueCertificatesRefresh(timeToRefresh) - return nil -} - -func (c *certRefresherImpl) scheduleIssueCertificatesRefresh(timeToRefresh time.Duration) { - c.setRefreshTimer(time.AfterFunc(timeToRefresh, func() { - certificates, issueErr := c.conf.issueCertificates(c.ctx) - if issueErr != nil { - log.Errorf("critical error issuing certificates %s: %s", - c.conf.certsDescription, issueErr) - c.recoverFromRefreshCrash() - } - nextTimeToRefresh, handleErr := c.conf.certificateSource.HandleCertificates(certificates) - if handleErr != nil { - log.Errorf("critical error processing certificates %s: %s", - c.conf.certsDescription, issueErr) - c.recoverFromRefreshCrash() - } - - log.Infof("successfully refreshed credentials for certificates %v", c.conf.certsDescription) - c.scheduleIssueCertificatesRefresh(nextTimeToRefresh) - })) - log.Infof("credentials for %v scheduled to be refreshed in %s", - c.conf.certsDescription, timeToRefresh) -} - -func (c *certRefresherImpl) setRefreshTimer(timer *time.Timer) { - if c.refreshTimer != nil { - 
c.refreshTimer.Stop() - } - c.refreshTimer = timer -} - -func (c *certRefresherImpl) recoverFromRefreshCrash() { - // TODO: consider backoff here. - c.setRefreshTimer(time.AfterFunc(refreshCrashWaitTime, func() { - err := c.initialRefresh() - if err != nil { - log.Error(err) - c.recoverFromRefreshCrash() - } - })) - log.Errorf("refresh process for %s will restart in %s", c.conf.certsDescription, refreshCrashWaitTime) -} From 539fb70ba5fe97b31a90d933b26ed56920f5f836 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Tue, 25 Jan 2022 18:00:06 +0100 Subject: [PATCH 13/34] remove RetryableSource --- pkg/retry/retry_source.go | 108 -------------------------------------- 1 file changed, 108 deletions(-) delete mode 100644 pkg/retry/retry_source.go diff --git a/pkg/retry/retry_source.go b/pkg/retry/retry_source.go deleted file mode 100644 index f85440eeaa158..0000000000000 --- a/pkg/retry/retry_source.go +++ /dev/null @@ -1,108 +0,0 @@ -package retry - -import ( - "context" - "time" - - "github.com/pkg/errors" - "k8s.io/apimachinery/pkg/util/wait" -) - -// RetryableSource is a proxy with an object that is able to compute a result, but -// that might forget our request, or return an error result, and that returns the -// result asynchronously. -// AskForResult() can be called to request a result, that should be make available in the -// returned channel. Each time AskForResult() is called the previously returned channel is abandoned. -// Retry() can be called several times to retry the result computation, the -// RetryableSource is in charge of handling the cancellation of the computation if needed. -type RetryableSource interface { - AskForResult(ctx context.Context) chan *Result - Retry() -} - -// Result wraps a pair (result, err) produced by a source. By convention -// either err or v has the zero value of its type. -type Result struct { - v interface{} - err error -} - -// RetryableSourceRetriever be used to retrieve the result in a RetryableSource. 
-type RetryableSourceRetriever struct { - // time to consider failed a call to AskForResult() that didn't return a result yet. - RequestTimeout time.Duration - // optionally specify a function to invoke on each error. waitDuration is the time until - // the next retry. - OnError func(err error, timeToNextRetry time.Duration) - // optionally specify a validation function for each result. - ValidateResult func(interface{}) bool - // should be reset between calls to Run. - Backoff wait.Backoff - timeoutC chan struct{} - timeoutTimer *time.Timer -} - -// NewRetryableSourceRetriever create a new NewRetryableSourceRetriever -func NewRetryableSourceRetriever(backoff wait.Backoff, requestTimeout time.Duration) *RetryableSourceRetriever { - return &RetryableSourceRetriever{ - RequestTimeout: requestTimeout, - Backoff: backoff, - } -} - -// Run gets the result from the specified source. -// Any timeout in ctx is respected. -func (r *RetryableSourceRetriever) Run(ctx context.Context, source RetryableSource) (interface{}, error) { - r.timeoutC = make(chan struct{}) - - resultC := source.AskForResult(ctx) - r.setTimeoutTimer(r.RequestTimeout) - defer r.setTimeoutTimer(-1) - for { - select { - case <-ctx.Done(): - return nil, errors.New("request cancelled") - case <-r.timeoutC: - // assume result will never come. 
- r.handleError(errors.New("timeout"), source) - case result := <-resultC: - err := result.err - if err != nil { - r.handleError(err, source) - } else { - if r.ValidateResult != nil && !r.ValidateResult(result.v) { - err := errors.Errorf("validation failed for value %v", result.v) - r.handleError(err, source) - } else { - return result.v, nil - } - } - } - } -} - -func (r *RetryableSourceRetriever) handleError(err error, source RetryableSource) { - waitDuration := r.Backoff.Step() - if r.OnError != nil { - r.OnError(err, waitDuration) - } - r.setTimeoutTimer(-1) - time.AfterFunc(waitDuration, func() { - source.Retry() - r.setTimeoutTimer(r.RequestTimeout) - }) -} - -// use negative timeout to just stop the timer. -func (r *RetryableSourceRetriever) setTimeoutTimer(timeout time.Duration) { - if r.timeoutTimer != nil { - r.timeoutTimer.Stop() - } - if timeout >= 0 { - r.timeoutTimer = time.AfterFunc(timeout, func() { - r.timeoutC <- struct{}{} - }) - } else { - r.timeoutTimer = nil - } -} From f0a9c3268ebc94e020afd76d1f953027378de394 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Tue, 25 Jan 2022 18:03:15 +0100 Subject: [PATCH 14/34] fix style --- pkg/concurrency/retry_ticker.go | 22 ++++++++++-------- pkg/concurrency/retry_ticker_test.go | 34 ++++++++++++++-------------- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index afb52c89dec88..bc651ee7eac11 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -9,13 +9,13 @@ import ( // RetryTicker repeatedly calls a function with a timeout and a retry backoff strategy. 
type RetryTicker struct { - f func(ctx context.Context) (timeToNextTick time.Duration, err error) - tickTimeout time.Duration + f func(ctx context.Context) (timeToNextTick time.Duration, err error) + tickTimeout time.Duration backoffPrototype wait.Backoff - backoff wait.Backoff - tickTimer *time.Timer - OnTickSuccess func(nextTimeToTick time.Duration) - OnTickError func(error) + backoff wait.Backoff + tickTimer *time.Timer + OnTickSuccess func(nextTimeToTick time.Duration) + OnTickError func(error) } // NewRetryTicker returns a new RetryTicker that calls the function f repeatedly: @@ -25,20 +25,22 @@ type RetryTicker struct { // - On success RetryTicker will reset backoff, and wait the amount of time returned by f before running f again. func NewRetryTicker(f func(ctx context.Context) (nextTimeToTick time.Duration, err error), tickTimeout time.Duration, - backoff wait.Backoff) *RetryTicker{ + backoff wait.Backoff) *RetryTicker { ticker := &RetryTicker{ - f: f, - tickTimeout: tickTimeout, + f: f, + tickTimeout: tickTimeout, backoffPrototype: backoff, - backoff: backoff, + backoff: backoff, } return ticker } +// Start calls t.f and schedules the next tick accordingly. func (t *RetryTicker) Start() { t.scheduleTick(0) } +// Stop cancels this RetryTicker. 
func (t *RetryTicker) Stop() { t.setTickTimer(nil) } diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index d707ceb551317..a020682a654b5 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -6,24 +6,24 @@ import ( "testing" "time" - "k8s.io/apimachinery/pkg/util/wait" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "k8s.io/apimachinery/pkg/util/wait" ) var ( pollingInterval = 10 * time.Millisecond - epsilonTime = 100 * time.Millisecond - longTime = 2 * time.Second - backoff = wait.Backoff{ + epsilonTime = 100 * time.Millisecond + longTime = 2 * time.Second + backoff = wait.Backoff{ Duration: epsilonTime, - Factor: 1, - Jitter: 0, - Steps: 1, - Cap: epsilonTime, + Factor: 1, + Jitter: 0, + Steps: 1, + Cap: epsilonTime, } ) + func TestHandler(t *testing.T) { suite.Run(t, new(retryTickerSuite)) } @@ -55,13 +55,13 @@ func (s *retryTickerSuite) TestRetryTicker() { // tickErr1 error // timeToNextTick2 time.Duration // tickErr2 error - forceError bool + forceError bool addEventHandlers bool }{ - "successWithEventHandlers": { forceError: false, addEventHandlers: true}, - "successWithoutEventHandlers": { forceError: false, addEventHandlers: false}, - "oneErrorWithEventHandlers": { forceError: true, addEventHandlers: true}, - "oneErrorWithoutEventHandlers": { forceError: true, addEventHandlers: false}, + "successWithEventHandlers": {forceError: false, addEventHandlers: true}, + "successWithoutEventHandlers": {forceError: false, addEventHandlers: false}, + "oneErrorWithEventHandlers": {forceError: true, addEventHandlers: true}, + "oneErrorWithoutEventHandlers": {forceError: true, addEventHandlers: false}, } for tcName, tc := range testCases { s.Run(tcName, func() { @@ -84,7 +84,7 @@ func (s *retryTickerSuite) TestRetryTicker() { m.On("f", mock.Anything).Return(longTime, nil).Run(func(args mock.Arguments) { done2.Set(true) }).Once() - if tc.addEventHandlers { + if 
tc.addEventHandlers { ticker.OnTickSuccess = m.OnTickSuccess ticker.OnTickError = m.OnTickError if !tc.forceError { @@ -100,9 +100,9 @@ func (s *retryTickerSuite) TestRetryTicker() { s.True(PollWithTimeout(done1.Get, pollingInterval, epsilonTime)) if !tc.forceError { - s.True(PollWithTimeout(done2.Get, pollingInterval, wait1 + epsilonTime)) + s.True(PollWithTimeout(done2.Get, pollingInterval, wait1+epsilonTime)) } else { - s.True(PollWithTimeout(done2.Get, pollingInterval, backoff.Cap + epsilonTime)) + s.True(PollWithTimeout(done2.Get, pollingInterval, backoff.Cap+epsilonTime)) } m.AssertExpectations(s.T()) From 88f52548162136523bd59973bca4f0afc48c38fa Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Wed, 26 Jan 2022 10:26:23 +0100 Subject: [PATCH 15/34] cleanup --- pkg/concurrency/retry_ticker_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index a020682a654b5..f950a0d668d85 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -51,10 +51,6 @@ func (f *testTickFun) OnTickError(err error) { func (s *retryTickerSuite) TestRetryTicker() { testCases := map[string]struct { - // timeToNextTick1 time.Duration - // tickErr1 error - // timeToNextTick2 time.Duration - // tickErr2 error forceError bool addEventHandlers bool }{ From fca9cd26901f95b9cd76a4dc3b2b8a624a36b67a Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Wed, 26 Jan 2022 10:38:14 +0100 Subject: [PATCH 16/34] make setTickTimer concurrently safe it is called from several goroutines in Start, Stop, and the goroutines created by time.AfterFunc --- pkg/concurrency/retry_ticker.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index bc651ee7eac11..f246ba7a1b89a 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -2,6 +2,7 @@ package concurrency import ( 
"context" + "sync" "time" "k8s.io/apimachinery/pkg/util/wait" @@ -14,6 +15,7 @@ type RetryTicker struct { backoffPrototype wait.Backoff backoff wait.Backoff tickTimer *time.Timer + tickTimerM sync.Mutex OnTickSuccess func(nextTimeToTick time.Duration) OnTickError func(error) } @@ -67,6 +69,8 @@ func (t *RetryTicker) scheduleTick(timeToTick time.Duration) { } func (t *RetryTicker) setTickTimer(timer *time.Timer) { + t.tickTimerM.Lock() + defer t.tickTimerM.Unlock() if t.tickTimer != nil { t.tickTimer.Stop() } From 773bbfbacabfc7207bd02980890afde61cd6cf83 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Wed, 26 Jan 2022 13:43:38 +0100 Subject: [PATCH 17/34] simplify test dropping suite --- pkg/concurrency/retry_ticker_test.go | 30 ++++++++++------------------ 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index f950a0d668d85..eb3fd8e3062bb 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -6,32 +6,24 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" "k8s.io/apimachinery/pkg/util/wait" ) var ( pollingInterval = 10 * time.Millisecond - epsilonTime = 100 * time.Millisecond + capTime = 100 * time.Millisecond longTime = 2 * time.Second backoff = wait.Backoff{ - Duration: epsilonTime, + Duration: capTime, Factor: 1, Jitter: 0, Steps: 1, - Cap: epsilonTime, + Cap: capTime, } ) -func TestHandler(t *testing.T) { - suite.Run(t, new(retryTickerSuite)) -} - -type retryTickerSuite struct { - suite.Suite -} - type testTickFun struct { mock.Mock } @@ -49,7 +41,7 @@ func (f *testTickFun) OnTickError(err error) { f.Called(err) } -func (s *retryTickerSuite) TestRetryTicker() { +func TestRetryTicker(t *testing.T) { testCases := map[string]struct { forceError bool addEventHandlers bool @@ -60,9 +52,9 @@ func (s *retryTickerSuite) TestRetryTicker() { 
"oneErrorWithoutEventHandlers": {forceError: true, addEventHandlers: false}, } for tcName, tc := range testCases { - s.Run(tcName, func() { + t.Run(tcName, func(t *testing.T) { var done1, done2 Flag - wait1 := 2 * epsilonTime + wait1 := 2 * capTime forcedErr := errors.New("forced") m := &testTickFun{} @@ -94,14 +86,14 @@ func (s *retryTickerSuite) TestRetryTicker() { ticker.Start() defer ticker.Stop() - s.True(PollWithTimeout(done1.Get, pollingInterval, epsilonTime)) + assert.True(t, PollWithTimeout(done1.Get, pollingInterval, capTime)) if !tc.forceError { - s.True(PollWithTimeout(done2.Get, pollingInterval, wait1+epsilonTime)) + assert.True(t, PollWithTimeout(done2.Get, pollingInterval, wait1+capTime)) } else { - s.True(PollWithTimeout(done2.Get, pollingInterval, backoff.Cap+epsilonTime)) + assert.True(t, PollWithTimeout(done2.Get, pollingInterval, backoff.Cap+capTime)) } - m.AssertExpectations(s.T()) + m.AssertExpectations(t) }) } } From 0f85a0b373c6dee8ca34e438d4e0ee9fa2fec3c6 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Wed, 26 Jan 2022 14:19:21 +0100 Subject: [PATCH 18/34] use builder pattern and make event handlers private in order to prevent concurrent modifications after Start --- pkg/concurrency/retry_ticker.go | 105 ++++++++++++++++++++------- pkg/concurrency/retry_ticker_test.go | 8 +- 2 files changed, 82 insertions(+), 31 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index f246ba7a1b89a..7ad6bd5b744ff 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -8,67 +8,64 @@ import ( "k8s.io/apimachinery/pkg/util/wait" ) +var ( + _ RetryTicker = (*retryTickerImpl)(nil) + _ RetryTickerBuilder = (*retryTickerBuilderImpl)(nil) +) + // RetryTicker repeatedly calls a function with a timeout and a retry backoff strategy. 
-type RetryTicker struct { - f func(ctx context.Context) (timeToNextTick time.Duration, err error) +type RetryTicker interface { + Start() + Stop() +} + +type retryTickerImpl struct { + f tickFunc tickTimeout time.Duration backoffPrototype wait.Backoff + onTickSuccess onTickSuccessFunc + onTickError onTickErrorFunc backoff wait.Backoff tickTimer *time.Timer tickTimerM sync.Mutex - OnTickSuccess func(nextTimeToTick time.Duration) - OnTickError func(error) } -// NewRetryTicker returns a new RetryTicker that calls the function f repeatedly: -// - When started, the RetryTicker calls f immediately, and if it returns an error -// then the RetryTicker will wait the time returned by backoff.Step before calling f again. -// - f must return an error if ctx is cancelled. RetryTicker always call f with a context with a timeout of tickTimeout. -// - On success RetryTicker will reset backoff, and wait the amount of time returned by f before running f again. -func NewRetryTicker(f func(ctx context.Context) (nextTimeToTick time.Duration, err error), - tickTimeout time.Duration, - backoff wait.Backoff) *RetryTicker { - ticker := &RetryTicker{ - f: f, - tickTimeout: tickTimeout, - backoffPrototype: backoff, - backoff: backoff, - } - return ticker -} +type tickFunc func(ctx context.Context) (timeToNextTick time.Duration, err error) +type onTickSuccessFunc func(nextTimeToTick time.Duration) +type onTickErrorFunc func(tickErr error) // Start calls t.f and schedules the next tick accordingly. -func (t *RetryTicker) Start() { +func (t *retryTickerImpl) Start() { t.scheduleTick(0) } // Stop cancels this RetryTicker. 
-func (t *RetryTicker) Stop() { +func (t *retryTickerImpl) Stop() { t.setTickTimer(nil) } -func (t *RetryTicker) scheduleTick(timeToTick time.Duration) { +func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { t.setTickTimer(time.AfterFunc(timeToTick, func() { ctx, cancel := context.WithTimeout(context.Background(), t.tickTimeout) defer cancel() nextTimeToTick, tickErr := t.f(ctx) if tickErr != nil { - if t.OnTickError != nil { - t.OnTickError(tickErr) + if t.onTickError != nil { + t.onTickError(tickErr) } t.scheduleTick(t.backoff.Step()) return } - if t.OnTickSuccess != nil { - t.OnTickSuccess(nextTimeToTick) + if t.onTickSuccess != nil { + t.onTickSuccess(nextTimeToTick) } t.backoff = t.backoffPrototype // reset backoff strategy t.scheduleTick(nextTimeToTick) })) } -func (t *RetryTicker) setTickTimer(timer *time.Timer) { +func (t *retryTickerImpl) setTickTimer(timer *time.Timer) { t.tickTimerM.Lock() defer t.tickTimerM.Unlock() if t.tickTimer != nil { @@ -76,3 +73,55 @@ func (t *RetryTicker) setTickTimer(timer *time.Timer) { } t.tickTimer = timer } + +// RetryTickerBuilder is a builder for RetryTicker objects. +type RetryTickerBuilder interface { + OnTickSuccess(onTickSuccessFunc) RetryTickerBuilder + OnTickError(onTickErrorFunc) RetryTickerBuilder + Build() RetryTicker +} + +// NewRetryTicker returns a new RetryTicker with the minimal parameters. See Build method below for +// details about how that is created. +func NewRetryTicker(f tickFunc, tickTimeout time.Duration, backoff wait.Backoff) RetryTicker { + return NewRetryTickerBuilder(f, tickTimeout, backoff).Build() +} + +// NewRetryTickerBuilder returns a builder for a RetryTicker that has been initialized with its mandatory parameters. 
+func NewRetryTickerBuilder(f tickFunc, tickTimeout time.Duration, backoff wait.Backoff) RetryTickerBuilder { + return &retryTickerBuilderImpl{f: f, tickTimeout: tickTimeout, backoffPrototype: backoff} +} + +// Build returns a new RetryTicker that calls the function f repeatedly: +// - When started, the RetryTicker calls f immediately, and if that returns an error +// then the RetryTicker will wait the time returned by backoff.Step before calling f again. +// - f must return an error if ctx is cancelled. RetryTicker always call f with a context with a timeout of tickTimeout. +// - On success RetryTicker will reset backoff, and wait the amount of time returned by f before running f again. +func (b *retryTickerBuilderImpl) Build() RetryTicker { + return &retryTickerImpl{ + f: b.f, + tickTimeout: b.tickTimeout, + backoffPrototype: b.backoffPrototype, + onTickSuccess: b.onTickSuccess, + onTickError: b.onTickError, + backoff: b.backoffPrototype, + } +} + +type retryTickerBuilderImpl struct { + f tickFunc + tickTimeout time.Duration + backoffPrototype wait.Backoff + onTickSuccess onTickSuccessFunc + onTickError onTickErrorFunc +} + +func (b *retryTickerBuilderImpl) OnTickSuccess(onTickSuccess onTickSuccessFunc) RetryTickerBuilder { + b.onTickSuccess = onTickSuccess + return b +} + +func (b *retryTickerBuilderImpl) OnTickError(onTickError onTickErrorFunc) RetryTickerBuilder { + b.onTickError = onTickError + return b +} diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index eb3fd8e3062bb..13631168541ab 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -58,7 +58,7 @@ func TestRetryTicker(t *testing.T) { forcedErr := errors.New("forced") m := &testTickFun{} - ticker := NewRetryTicker(m.f, longTime, backoff) + var ticker RetryTicker if !tc.forceError { m.On("f", mock.Anything).Return(wait1, nil).Run(func(args mock.Arguments) { @@ -73,14 +73,16 @@ func TestRetryTicker(t *testing.T) { 
done2.Set(true) }).Once() if tc.addEventHandlers { - ticker.OnTickSuccess = m.OnTickSuccess - ticker.OnTickError = m.OnTickError + tickerBuilder := NewRetryTickerBuilder(m.f, longTime, backoff) + ticker = tickerBuilder.OnTickSuccess(m.OnTickSuccess).OnTickError(m.OnTickError).Build() if !tc.forceError { m.On("OnTickSuccess", wait1).Once() } else { m.On("OnTickError", forcedErr).Once() } m.On("OnTickSuccess", longTime).Once() + } else { + ticker = NewRetryTicker(m.f, longTime, backoff) } ticker.Start() From 100099f03699f6f99fee9a0d54fc2cccd583fc6e Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Thu, 27 Jan 2022 16:38:55 +0100 Subject: [PATCH 19/34] use import from pkg/sync --- pkg/concurrency/retry_ticker.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index 7ad6bd5b744ff..4b930a7c31308 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -2,9 +2,9 @@ package concurrency import ( "context" - "sync" "time" + "github.com/stackrox/rox/pkg/sync" "k8s.io/apimachinery/pkg/util/wait" ) From 42af3b5a41de7dd332965c983b08fde4a5dc5db9 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Thu, 27 Jan 2022 17:17:38 +0100 Subject: [PATCH 20/34] remove handler functions and builder --- pkg/concurrency/retry_ticker.go | 61 ++++------------------------ pkg/concurrency/retry_ticker_test.go | 22 ++-------- 2 files changed, 13 insertions(+), 70 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index 4b930a7c31308..89a6aacd4d453 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -9,8 +9,7 @@ import ( ) var ( - _ RetryTicker = (*retryTickerImpl)(nil) - _ RetryTickerBuilder = (*retryTickerBuilderImpl)(nil) + _ RetryTicker = (*retryTickerImpl)(nil) ) // RetryTicker repeatedly calls a function with a timeout and a retry backoff strategy. 
@@ -20,11 +19,9 @@ type RetryTicker interface { } type retryTickerImpl struct { - f tickFunc + fn tickFunc tickTimeout time.Duration backoffPrototype wait.Backoff - onTickSuccess onTickSuccessFunc - onTickError onTickErrorFunc backoff wait.Backoff tickTimer *time.Timer tickTimerM sync.Mutex @@ -34,7 +31,7 @@ type tickFunc func(ctx context.Context) (timeToNextTick time.Duration, err error type onTickSuccessFunc func(nextTimeToTick time.Duration) type onTickErrorFunc func(tickErr error) -// Start calls t.f and schedules the next tick accordingly. +// Start calls t.f and schedules the next tick immediately. func (t *retryTickerImpl) Start() { t.scheduleTick(0) } @@ -49,17 +46,11 @@ func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { ctx, cancel := context.WithTimeout(context.Background(), t.tickTimeout) defer cancel() - nextTimeToTick, tickErr := t.f(ctx) + nextTimeToTick, tickErr := t.fn(ctx) if tickErr != nil { - if t.onTickError != nil { - t.onTickError(tickErr) - } t.scheduleTick(t.backoff.Step()) return } - if t.onTickSuccess != nil { - t.onTickSuccess(nextTimeToTick) - } t.backoff = t.backoffPrototype // reset backoff strategy t.scheduleTick(nextTimeToTick) })) @@ -83,45 +74,11 @@ type RetryTickerBuilder interface { // NewRetryTicker returns a new RetryTicker with the minimal parameters. See Build method below for // details about how that is created. -func NewRetryTicker(f tickFunc, tickTimeout time.Duration, backoff wait.Backoff) RetryTicker { - return NewRetryTickerBuilder(f, tickTimeout, backoff).Build() -} - -// NewRetryTickerBuilder returns a builder for a RetryTicker that has been initialized with its mandatory parameters. 
-func NewRetryTickerBuilder(f tickFunc, tickTimeout time.Duration, backoff wait.Backoff) RetryTickerBuilder { - return &retryTickerBuilderImpl{f: f, tickTimeout: tickTimeout, backoffPrototype: backoff} -} - -// Build returns a new RetryTicker that calls the function f repeatedly: -// - When started, the RetryTicker calls f immediately, and if that returns an error -// then the RetryTicker will wait the time returned by backoff.Step before calling f again. -// - f must return an error if ctx is cancelled. RetryTicker always call f with a context with a timeout of tickTimeout. -// - On success RetryTicker will reset backoff, and wait the amount of time returned by f before running f again. -func (b *retryTickerBuilderImpl) Build() RetryTicker { +func NewRetryTicker(fn tickFunc, tickTimeout time.Duration, backoff wait.Backoff) RetryTicker { return &retryTickerImpl{ - f: b.f, - tickTimeout: b.tickTimeout, - backoffPrototype: b.backoffPrototype, - onTickSuccess: b.onTickSuccess, - onTickError: b.onTickError, - backoff: b.backoffPrototype, + fn: fn, + tickTimeout: tickTimeout, + backoffPrototype: backoff, + backoff: backoff, } } - -type retryTickerBuilderImpl struct { - f tickFunc - tickTimeout time.Duration - backoffPrototype wait.Backoff - onTickSuccess onTickSuccessFunc - onTickError onTickErrorFunc -} - -func (b *retryTickerBuilderImpl) OnTickSuccess(onTickSuccess onTickSuccessFunc) RetryTickerBuilder { - b.onTickSuccess = onTickSuccess - return b -} - -func (b *retryTickerBuilderImpl) OnTickError(onTickError onTickErrorFunc) RetryTickerBuilder { - b.onTickError = onTickError - return b -} diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index 13631168541ab..e5baa45d4918a 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -43,13 +43,10 @@ func (f *testTickFun) OnTickError(err error) { func TestRetryTicker(t *testing.T) { testCases := map[string]struct { - forceError bool - 
addEventHandlers bool + forceError bool }{ - "successWithEventHandlers": {forceError: false, addEventHandlers: true}, - "successWithoutEventHandlers": {forceError: false, addEventHandlers: false}, - "oneErrorWithEventHandlers": {forceError: true, addEventHandlers: true}, - "oneErrorWithoutEventHandlers": {forceError: true, addEventHandlers: false}, + "success": {forceError: false}, + "oneError": {forceError: true}, } for tcName, tc := range testCases { t.Run(tcName, func(t *testing.T) { @@ -72,18 +69,7 @@ func TestRetryTicker(t *testing.T) { m.On("f", mock.Anything).Return(longTime, nil).Run(func(args mock.Arguments) { done2.Set(true) }).Once() - if tc.addEventHandlers { - tickerBuilder := NewRetryTickerBuilder(m.f, longTime, backoff) - ticker = tickerBuilder.OnTickSuccess(m.OnTickSuccess).OnTickError(m.OnTickError).Build() - if !tc.forceError { - m.On("OnTickSuccess", wait1).Once() - } else { - m.On("OnTickError", forcedErr).Once() - } - m.On("OnTickSuccess", longTime).Once() - } else { - ticker = NewRetryTicker(m.f, longTime, backoff) - } + ticker = NewRetryTicker(m.f, longTime, backoff) ticker.Start() defer ticker.Stop() From 2c0954a9dffbf0dab0291860c486223cc3b7400d Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Thu, 27 Jan 2022 17:22:06 +0100 Subject: [PATCH 21/34] use terser field names --- pkg/concurrency/retry_ticker.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index 89a6aacd4d453..33337d17daa25 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -19,12 +19,12 @@ type RetryTicker interface { } type retryTickerImpl struct { - fn tickFunc - tickTimeout time.Duration + doFunc tickFunc + timeout time.Duration backoffPrototype wait.Backoff backoff wait.Backoff - tickTimer *time.Timer - tickTimerM sync.Mutex + timer *time.Timer + mutex sync.Mutex } type tickFunc func(ctx context.Context) 
(timeToNextTick time.Duration, err error) @@ -43,10 +43,10 @@ func (t *retryTickerImpl) Stop() { func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { t.setTickTimer(time.AfterFunc(timeToTick, func() { - ctx, cancel := context.WithTimeout(context.Background(), t.tickTimeout) + ctx, cancel := context.WithTimeout(context.Background(), t.timeout) defer cancel() - nextTimeToTick, tickErr := t.fn(ctx) + nextTimeToTick, tickErr := t.doFunc(ctx) if tickErr != nil { t.scheduleTick(t.backoff.Step()) return @@ -57,12 +57,12 @@ func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { } func (t *retryTickerImpl) setTickTimer(timer *time.Timer) { - t.tickTimerM.Lock() - defer t.tickTimerM.Unlock() - if t.tickTimer != nil { - t.tickTimer.Stop() + t.mutex.Lock() + defer t.mutex.Unlock() + if t.timer != nil { + t.timer.Stop() } - t.tickTimer = timer + t.timer = timer } // RetryTickerBuilder is a builder for RetryTicker objects. @@ -74,10 +74,10 @@ type RetryTickerBuilder interface { // NewRetryTicker returns a new RetryTicker with the minimal parameters. See Build method below for // details about how that is created. 
-func NewRetryTicker(fn tickFunc, tickTimeout time.Duration, backoff wait.Backoff) RetryTicker { +func NewRetryTicker(doFunc tickFunc, timeout time.Duration, backoff wait.Backoff) RetryTicker { return &retryTickerImpl{ - fn: fn, - tickTimeout: tickTimeout, + doFunc: doFunc, + timeout: timeout, backoffPrototype: backoff, backoff: backoff, } From 62da8c7982614a98f20eabee51c8de15ca3c899c Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Mon, 31 Jan 2022 11:19:32 +0100 Subject: [PATCH 22/34] cleanup code lingering from previous change --- pkg/concurrency/retry_ticker.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index 33337d17daa25..156b40ea2be23 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -28,8 +28,6 @@ type retryTickerImpl struct { } type tickFunc func(ctx context.Context) (timeToNextTick time.Duration, err error) -type onTickSuccessFunc func(nextTimeToTick time.Duration) -type onTickErrorFunc func(tickErr error) // Start calls t.f and schedules the next tick immediately. func (t *retryTickerImpl) Start() { @@ -65,13 +63,6 @@ func (t *retryTickerImpl) setTickTimer(timer *time.Timer) { t.timer = timer } -// RetryTickerBuilder is a builder for RetryTicker objects. -type RetryTickerBuilder interface { - OnTickSuccess(onTickSuccessFunc) RetryTickerBuilder - OnTickError(onTickErrorFunc) RetryTickerBuilder - Build() RetryTicker -} - // NewRetryTicker returns a new RetryTicker with the minimal parameters. See Build method below for // details about how that is created. 
func NewRetryTicker(doFunc tickFunc, timeout time.Duration, backoff wait.Backoff) RetryTicker { From f21698442c4d886dd05552e0bb20672907776173 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Mon, 31 Jan 2022 11:22:13 +0100 Subject: [PATCH 23/34] rename backoffPrototype to initialBackoff --- pkg/concurrency/retry_ticker.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index 156b40ea2be23..cd9c242e79e61 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -19,12 +19,12 @@ type RetryTicker interface { } type retryTickerImpl struct { - doFunc tickFunc - timeout time.Duration - backoffPrototype wait.Backoff - backoff wait.Backoff - timer *time.Timer - mutex sync.Mutex + doFunc tickFunc + timeout time.Duration + initialBackoff wait.Backoff + backoff wait.Backoff + timer *time.Timer + mutex sync.Mutex } type tickFunc func(ctx context.Context) (timeToNextTick time.Duration, err error) @@ -49,7 +49,7 @@ func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { t.scheduleTick(t.backoff.Step()) return } - t.backoff = t.backoffPrototype // reset backoff strategy + t.backoff = t.initialBackoff // reset backoff strategy t.scheduleTick(nextTimeToTick) })) } @@ -67,9 +67,9 @@ func (t *retryTickerImpl) setTickTimer(timer *time.Timer) { // details about how that is created. 
func NewRetryTicker(doFunc tickFunc, timeout time.Duration, backoff wait.Backoff) RetryTicker { return &retryTickerImpl{ - doFunc: doFunc, - timeout: timeout, - backoffPrototype: backoff, - backoff: backoff, + doFunc: doFunc, + timeout: timeout, + initialBackoff: backoff, + backoff: backoff, } } From aab2d51eaf65fa913b347d7d0ce46ee07b0361dd Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Mon, 31 Jan 2022 11:28:39 +0100 Subject: [PATCH 24/34] cleanup test and make it easier to read --- pkg/concurrency/retry_ticker_test.go | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index e5baa45d4918a..6e0b37ac5d7cf 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -33,20 +33,12 @@ func (f *testTickFun) f(ctx context.Context) (nextTimeToTick time.Duration, err return args.Get(0).(time.Duration), args.Error(1) } -func (f *testTickFun) OnTickSuccess(nextTimeToTick time.Duration) { - f.Called(nextTimeToTick) -} - -func (f *testTickFun) OnTickError(err error) { - f.Called(err) -} - func TestRetryTicker(t *testing.T) { testCases := map[string]struct { - forceError bool + expectError bool }{ - "success": {forceError: false}, - "oneError": {forceError: true}, + "success": {expectError: false}, + "with error should retry": {expectError: true}, } for tcName, tc := range testCases { t.Run(tcName, func(t *testing.T) { @@ -57,12 +49,12 @@ func TestRetryTicker(t *testing.T) { m := &testTickFun{} var ticker RetryTicker - if !tc.forceError { - m.On("f", mock.Anything).Return(wait1, nil).Run(func(args mock.Arguments) { + if tc.expectError { + m.On("f", mock.Anything).Return(time.Duration(0), forcedErr).Run(func(args mock.Arguments) { done1.Set(true) }).Once() } else { - m.On("f", mock.Anything).Return(time.Duration(0), forcedErr).Run(func(args mock.Arguments) { + m.On("f", mock.Anything).Return(wait1, 
nil).Run(func(args mock.Arguments) { done1.Set(true) }).Once() } @@ -75,10 +67,10 @@ func TestRetryTicker(t *testing.T) { defer ticker.Stop() assert.True(t, PollWithTimeout(done1.Get, pollingInterval, capTime)) - if !tc.forceError { - assert.True(t, PollWithTimeout(done2.Get, pollingInterval, wait1+capTime)) - } else { + if tc.expectError { assert.True(t, PollWithTimeout(done2.Get, pollingInterval, backoff.Cap+capTime)) + } else { + assert.True(t, PollWithTimeout(done2.Get, pollingInterval, wait1+capTime)) } m.AssertExpectations(t) From 5e8046525517674436ff065c2bcc1cc48acf936e Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Mon, 31 Jan 2022 12:21:44 +0100 Subject: [PATCH 25/34] try to make test more readable moving config to test case struct --- pkg/concurrency/retry_ticker_test.go | 41 ++++++++++++---------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index 6e0b37ac5d7cf..6c7cce727303d 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -13,8 +13,8 @@ import ( var ( pollingInterval = 10 * time.Millisecond - capTime = 100 * time.Millisecond - longTime = 2 * time.Second + capTime = 500 * time.Millisecond + longTime = 5 * time.Second backoff = wait.Backoff{ Duration: capTime, Factor: 1, @@ -28,36 +28,28 @@ type testTickFun struct { mock.Mock } -func (f *testTickFun) f(ctx context.Context) (nextTimeToTick time.Duration, err error) { +func (f *testTickFun) f(ctx context.Context) (timeToNextTick time.Duration, err error) { args := f.Called(ctx) return args.Get(0).(time.Duration), args.Error(1) } func TestRetryTicker(t *testing.T) { testCases := map[string]struct { - expectError bool + timeToSecondTick time.Duration + firstErr error }{ - "success": {expectError: false}, - "with error should retry": {expectError: true}, + "success": {timeToSecondTick: 2 * capTime, firstErr: nil}, + "with error should retry": 
{timeToSecondTick: 0, firstErr: errors.New("forced")}, } for tcName, tc := range testCases { t.Run(tcName, func(t *testing.T) { var done1, done2 Flag - wait1 := 2 * capTime - forcedErr := errors.New("forced") - m := &testTickFun{} var ticker RetryTicker - if tc.expectError { - m.On("f", mock.Anything).Return(time.Duration(0), forcedErr).Run(func(args mock.Arguments) { - done1.Set(true) - }).Once() - } else { - m.On("f", mock.Anything).Return(wait1, nil).Run(func(args mock.Arguments) { - done1.Set(true) - }).Once() - } + m.On("f", mock.Anything).Return(tc.timeToSecondTick, tc.firstErr).Run(func(args mock.Arguments) { + done1.Set(true) + }).Once() m.On("f", mock.Anything).Return(longTime, nil).Run(func(args mock.Arguments) { done2.Set(true) }).Once() @@ -66,14 +58,17 @@ func TestRetryTicker(t *testing.T) { ticker.Start() defer ticker.Stop() + // this should happen immediately, we add capTime to give some margin to make test more stable. assert.True(t, PollWithTimeout(done1.Get, pollingInterval, capTime)) - if tc.expectError { - assert.True(t, PollWithTimeout(done2.Get, pollingInterval, backoff.Cap+capTime)) + + var expectedTimeToSecondAttempt time.Duration + if tc.firstErr == nil { + expectedTimeToSecondAttempt = tc.timeToSecondTick } else { - assert.True(t, PollWithTimeout(done2.Get, pollingInterval, wait1+capTime)) + expectedTimeToSecondAttempt = backoff.Cap } - - m.AssertExpectations(t) + // we add capTime to give some margin to make test more stable. 
+ assert.True(t, PollWithTimeout(done2.Get, pollingInterval, expectedTimeToSecondAttempt+capTime)) }) } } From 905ba2639ede890203d761358b5ea8fc0c33b54e Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Mon, 31 Jan 2022 12:35:27 +0100 Subject: [PATCH 26/34] improve comments --- pkg/concurrency/retry_ticker.go | 34 +++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index cd9c242e79e61..46564d28cbafa 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -18,6 +18,24 @@ type RetryTicker interface { Stop() } +type tickFunc func(ctx context.Context) (timeToNextTick time.Duration, err error) + +// NewRetryTicker returns a new RetryTicker that calls the "tick function" `doFunc` repeatedly: +// - When started, the RetryTicker calls `doFunc` immediately, and if that returns an error +// then the RetryTicker will wait the time returned by `backoff.Step` before calling `doFunc` again. +// - `doFunc` should return an error if ctx is cancelled. RetryTicker always calls `doFunc` with a context +// with a timeout of `timeout`. +// - On success `RetryTicker` will reset `backoff`, and wait the amount of time returned by `doFunc` before +// running it again. +func NewRetryTicker(doFunc tickFunc, timeout time.Duration, backoff wait.Backoff) RetryTicker { + return &retryTickerImpl{ + doFunc: doFunc, + timeout: timeout, + initialBackoff: backoff, + backoff: backoff, + } +} + type retryTickerImpl struct { doFunc tickFunc timeout time.Duration @@ -27,14 +45,13 @@ type retryTickerImpl struct { mutex sync.Mutex } -type tickFunc func(ctx context.Context) (timeToNextTick time.Duration, err error) - // Start calls t.f and schedules the next tick immediately. func (t *retryTickerImpl) Start() { t.scheduleTick(0) } -// Stop cancels this RetryTicker. +// Stop cancels this RetryTicker. 
If Stop is called while the tick function is running then Stop does not +// wait for the tick function to complete before returning. func (t *retryTickerImpl) Stop() { t.setTickTimer(nil) } @@ -62,14 +79,3 @@ func (t *retryTickerImpl) setTickTimer(timer *time.Timer) { } t.timer = timer } - -// NewRetryTicker returns a new RetryTicker with the minimal parameters. See Build method below for -// details about how that is created. -func NewRetryTicker(doFunc tickFunc, timeout time.Duration, backoff wait.Backoff) RetryTicker { - return &retryTickerImpl{ - doFunc: doFunc, - timeout: timeout, - initialBackoff: backoff, - backoff: backoff, - } -} From 8ce68199ba059a5b1c5f2143ba0eb91faf79e9aa Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Tue, 1 Feb 2022 11:24:54 +0100 Subject: [PATCH 27/34] rename mock f function as "doTick" --- pkg/concurrency/retry_ticker.go | 2 +- pkg/concurrency/retry_ticker_test.go | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index 46564d28cbafa..6fac0acd2ed98 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -45,7 +45,7 @@ type retryTickerImpl struct { mutex sync.Mutex } -// Start calls t.f and schedules the next tick immediately. +// Start calls the tick function and schedules the next tick immediately. 
func (t *retryTickerImpl) Start() { t.scheduleTick(0) } diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index 6c7cce727303d..4df0cef08b01c 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -28,7 +28,7 @@ type testTickFun struct { mock.Mock } -func (f *testTickFun) f(ctx context.Context) (timeToNextTick time.Duration, err error) { +func (f *testTickFun) doTick(ctx context.Context) (timeToNextTick time.Duration, err error) { args := f.Called(ctx) return args.Get(0).(time.Duration), args.Error(1) } @@ -47,13 +47,13 @@ func TestRetryTicker(t *testing.T) { m := &testTickFun{} var ticker RetryTicker - m.On("f", mock.Anything).Return(tc.timeToSecondTick, tc.firstErr).Run(func(args mock.Arguments) { + m.On("doTick", mock.Anything).Return(tc.timeToSecondTick, tc.firstErr).Run(func(args mock.Arguments) { done1.Set(true) }).Once() - m.On("f", mock.Anything).Return(longTime, nil).Run(func(args mock.Arguments) { + m.On("doTick", mock.Anything).Return(longTime, nil).Run(func(args mock.Arguments) { done2.Set(true) }).Once() - ticker = NewRetryTicker(m.f, longTime, backoff) + ticker = NewRetryTicker(m.doTick, longTime, backoff) ticker.Start() defer ticker.Stop() From 09695695d2ebc17b1515632bfd501994be5eb739 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Tue, 1 Feb 2022 12:52:15 +0100 Subject: [PATCH 28/34] make test runs faster by reducing waits --- pkg/concurrency/retry_ticker.go | 4 +- pkg/concurrency/retry_ticker_test.go | 75 +++++++++++++++++----------- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index 6fac0acd2ed98..e46af2429eab2 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -29,6 +29,7 @@ type tickFunc func(ctx context.Context) (timeToNextTick time.Duration, err error // running it again. 
func NewRetryTicker(doFunc tickFunc, timeout time.Duration, backoff wait.Backoff) RetryTicker { return &retryTickerImpl{ + scheduler: time.AfterFunc, doFunc: doFunc, timeout: timeout, initialBackoff: backoff, @@ -37,6 +38,7 @@ func NewRetryTicker(doFunc tickFunc, timeout time.Duration, backoff wait.Backoff } type retryTickerImpl struct { + scheduler func(d time.Duration, f func()) *time.Timer doFunc tickFunc timeout time.Duration initialBackoff wait.Backoff @@ -57,7 +59,7 @@ func (t *retryTickerImpl) Stop() { } func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { - t.setTickTimer(time.AfterFunc(timeToTick, func() { + t.setTickTimer(t.scheduler(timeToTick, func() { ctx, cancel := context.WithTimeout(context.Background(), t.timeout) defer cancel() diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index 4df0cef08b01c..2653dfb3c653b 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -8,67 +8,86 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/util/wait" ) var ( - pollingInterval = 10 * time.Millisecond - capTime = 500 * time.Millisecond - longTime = 5 * time.Second - backoff = wait.Backoff{ + testTimeout = 1 * time.Second + longTime = 5 * time.Second + capTime = 100 * time.Millisecond + backoff = wait.Backoff{ Duration: capTime, Factor: 1, Jitter: 0, - Steps: 1, + Steps: 2, Cap: capTime, } ) -type testTickFun struct { +type testTickFunc struct { mock.Mock } -func (f *testTickFun) doTick(ctx context.Context) (timeToNextTick time.Duration, err error) { +func (f *testTickFunc) doTick(ctx context.Context) (timeToNextTick time.Duration, err error) { args := f.Called(ctx) return args.Get(0).(time.Duration), args.Error(1) } +func (f *testTickFunc) Step() time.Duration { + f.Called() + return 0 +} + +type afterFuncSpy struct { + mock.Mock +} + +func (f *afterFuncSpy) afterFunc(d 
time.Duration, fn func()) *time.Timer { + f.Called(d) + return time.AfterFunc(d, fn) +} + func TestRetryTicker(t *testing.T) { testCases := map[string]struct { timeToSecondTick time.Duration firstErr error }{ - "success": {timeToSecondTick: 2 * capTime, firstErr: nil}, + "success": {timeToSecondTick: capTime, firstErr: nil}, "with error should retry": {timeToSecondTick: 0, firstErr: errors.New("forced")}, } for tcName, tc := range testCases { t.Run(tcName, func(t *testing.T) { - var done1, done2 Flag - m := &testTickFun{} - var ticker RetryTicker + doneErrSig := NewErrorSignal() + mockFunc := &testTickFunc{} + schedulerSpy := &afterFuncSpy{} - m.On("doTick", mock.Anything).Return(tc.timeToSecondTick, tc.firstErr).Run(func(args mock.Arguments) { - done1.Set(true) + mockFunc.On("doTick", mock.Anything).Return(tc.timeToSecondTick, tc.firstErr).Once() + mockFunc.On("doTick", mock.Anything).Return(longTime, nil).Run(func(args mock.Arguments) { + doneErrSig.Signal() }).Once() - m.On("doTick", mock.Anything).Return(longTime, nil).Run(func(args mock.Arguments) { - done2.Set(true) - }).Once() - ticker = NewRetryTicker(m.doTick, longTime, backoff) + mockFunc.On("doTick", mock.Anything).Return(longTime, nil).Maybe() - ticker.Start() - defer ticker.Stop() - - // this should happen immediately, we add capTime to give some margin to make test more stable. - assert.True(t, PollWithTimeout(done1.Get, pollingInterval, capTime)) - - var expectedTimeToSecondAttempt time.Duration + schedulerSpy.On("afterFunc", time.Duration(0), mock.Anything).Return(nil).Once() if tc.firstErr == nil { - expectedTimeToSecondAttempt = tc.timeToSecondTick + schedulerSpy.On("afterFunc", tc.timeToSecondTick, mock.Anything).Return(nil).Once() } else { - expectedTimeToSecondAttempt = backoff.Cap + schedulerSpy.On("afterFunc", backoff.Duration, mock.Anything).Return(nil).Once() } - // we add capTime to give some margin to make test more stable. 
- assert.True(t, PollWithTimeout(done2.Get, pollingInterval, expectedTimeToSecondAttempt+capTime)) + schedulerSpy.On("afterFunc", longTime, mock.Anything).Return(nil).Maybe() + + newTicker := NewRetryTicker(mockFunc.doTick, longTime, backoff) + require.IsType(t, &retryTickerImpl{}, newTicker) + ticker := newTicker.(*retryTickerImpl) + ticker.scheduler = schedulerSpy.afterFunc + + ticker.Start() + defer ticker.Stop() + + _, ok := doneErrSig.WaitWithTimeout(testTimeout) + assert.True(t, ok) + mockFunc.AssertExpectations(t) + schedulerSpy.AssertExpectations(t) }) } } From b4626cff01bec851877bfc24a77fc0ec9d56322d Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Tue, 1 Feb 2022 14:56:59 +0100 Subject: [PATCH 29/34] fix bug where stopped timer was restarted by scheduleTick --- pkg/concurrency/retry_ticker.go | 12 ++++++++- pkg/concurrency/retry_ticker_test.go | 40 +++++++++++++++++++++------- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index e46af2429eab2..e3522d6e0ff06 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -44,7 +44,7 @@ type retryTickerImpl struct { initialBackoff wait.Backoff backoff wait.Backoff timer *time.Timer - mutex sync.Mutex + mutex sync.RWMutex } // Start calls the tick function and schedules the next tick immediately. @@ -64,6 +64,10 @@ func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { defer cancel() nextTimeToTick, tickErr := t.doFunc(ctx) + if t.getTickTimer() == nil { + // timer was cancelled while tick function was running. 
+ return + } if tickErr != nil { t.scheduleTick(t.backoff.Step()) return @@ -81,3 +85,9 @@ func (t *retryTickerImpl) setTickTimer(timer *time.Timer) { } t.timer = timer } + +func (t *retryTickerImpl) getTickTimer() *time.Timer { + t.mutex.RLock() + defer t.mutex.RUnlock() + return t.timer +} diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index 2653dfb3c653b..46627d54ff53e 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -48,7 +48,7 @@ func (f *afterFuncSpy) afterFunc(d time.Duration, fn func()) *time.Timer { return time.AfterFunc(d, fn) } -func TestRetryTicker(t *testing.T) { +func TestRetryTickerCallsTickFunction(t *testing.T) { testCases := map[string]struct { timeToSecondTick time.Duration firstErr error @@ -61,25 +61,20 @@ func TestRetryTicker(t *testing.T) { doneErrSig := NewErrorSignal() mockFunc := &testTickFunc{} schedulerSpy := &afterFuncSpy{} + ticker := newRetryTicker(t, mockFunc.doTick) + ticker.scheduler = schedulerSpy.afterFunc mockFunc.On("doTick", mock.Anything).Return(tc.timeToSecondTick, tc.firstErr).Once() mockFunc.On("doTick", mock.Anything).Return(longTime, nil).Run(func(args mock.Arguments) { + ticker.Stop() doneErrSig.Signal() }).Once() - mockFunc.On("doTick", mock.Anything).Return(longTime, nil).Maybe() - schedulerSpy.On("afterFunc", time.Duration(0), mock.Anything).Return(nil).Once() if tc.firstErr == nil { schedulerSpy.On("afterFunc", tc.timeToSecondTick, mock.Anything).Return(nil).Once() } else { schedulerSpy.On("afterFunc", backoff.Duration, mock.Anything).Return(nil).Once() } - schedulerSpy.On("afterFunc", longTime, mock.Anything).Return(nil).Maybe() - - newTicker := NewRetryTicker(mockFunc.doTick, longTime, backoff) - require.IsType(t, &retryTickerImpl{}, newTicker) - ticker := newTicker.(*retryTickerImpl) - ticker.scheduler = schedulerSpy.afterFunc ticker.Start() defer ticker.Stop() @@ -91,3 +86,30 @@ func TestRetryTicker(t *testing.T) { }) } } + 
+func TestRetryTickerStop(t *testing.T) { + firsTickErrSig := NewErrorSignal() + stopErrSig := NewErrorSignal() + ticker := newRetryTicker(t, func(ctx context.Context) (timeToNextTick time.Duration, err error) { + firsTickErrSig.Signal() + _, ok := stopErrSig.WaitWithTimeout(testTimeout) + require.True(t, ok) + return capTime, nil + }) + + ticker.Start() + _, ok := firsTickErrSig.WaitWithTimeout(testTimeout) + require.True(t, ok) + ticker.Stop() + stopErrSig.Signal() + + // ensure `ticker.scheduleTick` does not schedule a new timer after stopping the ticker + time.Sleep(capTime) + assert.Nil(t, ticker.getTickTimer()) +} + +func newRetryTicker(t *testing.T, doFunc tickFunc) *retryTickerImpl { + ticker := NewRetryTicker(doFunc, longTime, backoff) + require.IsType(t, &retryTickerImpl{}, ticker) + return ticker.(*retryTickerImpl) +} From 7f8d96ef8f4427de768865f0125afbaf85584c6c Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Tue, 1 Feb 2022 14:59:15 +0100 Subject: [PATCH 30/34] reset backoff on start --- pkg/concurrency/retry_ticker.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index e3522d6e0ff06..5e5abddcd6ba1 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -49,6 +49,7 @@ type retryTickerImpl struct { // Start calls the tick function and schedules the next tick immediately. func (t *retryTickerImpl) Start() { + t.backoff = t.initialBackoff // reset backoff strategy t.scheduleTick(0) } @@ -65,7 +66,7 @@ func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { nextTimeToTick, tickErr := t.doFunc(ctx) if t.getTickTimer() == nil { - // timer was cancelled while tick function was running. + // ticker was stopped while tick function was running. 
+ return } if tickErr != nil { From 262d3c0745afe3c0d46f1f4a2ab46764f09cdf2e Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Thu, 3 Feb 2022 12:39:03 +0100 Subject: [PATCH 31/34] improve comments --- pkg/concurrency/retry_ticker.go | 2 +- pkg/concurrency/retry_ticker_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index 5e5abddcd6ba1..d080d89096af4 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -49,7 +49,7 @@ type retryTickerImpl struct { // Start calls the tick function and schedules the next tick immediately. func (t *retryTickerImpl) Start() { - t.backoff = t.initialBackoff // reset backoff strategy + t.backoff = t.initialBackoff // initialize backoff strategy t.scheduleTick(0) } diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index 46627d54ff53e..e7852d0fbf1c9 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -80,7 +80,7 @@ func TestRetryTickerCallsTickFunction(t *testing.T) { defer ticker.Stop() _, ok := doneErrSig.WaitWithTimeout(testTimeout) - assert.True(t, ok) + assert.True(t, ok, "timeout exceeded") mockFunc.AssertExpectations(t) schedulerSpy.AssertExpectations(t) }) @@ -99,7 +99,7 @@ func TestRetryTickerStop(t *testing.T) { ticker.Start() _, ok := firsTickErrSig.WaitWithTimeout(testTimeout) - require.True(t, ok) + require.True(t, ok, "timeout exceeded") ticker.Stop() stopErrSig.Signal() From 1c0c961ac5ad313b24fb12c6eb9d881af5143681 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Thu, 3 Feb 2022 15:23:23 +0100 Subject: [PATCH 32/34] only allow starting the ticker once - starting a started ticker leads to timer interferences - starting a stopped ticker is not safe even in the same goroutine because a call to scheduleTick from a first Start can continue after a second Start is called, and that leads to concurrent
modifications of t.backoff Also fix memory leak in test due to non stopped tickers --- pkg/concurrency/retry_ticker.go | 25 ++++++++++++++++++++--- pkg/concurrency/retry_ticker_test.go | 30 ++++++++++++++++++++++++---- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index d080d89096af4..3acf4621d388b 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -4,17 +4,24 @@ import ( "context" "time" + "github.com/pkg/errors" "github.com/stackrox/rox/pkg/sync" "k8s.io/apimachinery/pkg/util/wait" ) var ( _ RetryTicker = (*retryTickerImpl)(nil) + // ErrStartedTimer is returned when Start is called on a timer that was already started. + ErrStartedTimer = errors.New("started timer") + // ErrStoppedTimer is returned when Start is called on a timer that was stopped. + ErrStoppedTimer = errors.New("stopped timer") ) // RetryTicker repeatedly calls a function with a timeout and a retry backoff strategy. +// RetryTickers can only be started once. +// RetryTickers are not safe for simultaneous use by multiple goroutines. type RetryTicker interface { - Start() + Start() error Stop() } @@ -45,17 +52,29 @@ type retryTickerImpl struct { backoff wait.Backoff timer *time.Timer mutex sync.RWMutex + stopFlag Flag } // Start calls the tick function and schedules the next tick immediately. -func (t *retryTickerImpl) Start() { +// Start returns and error if the RetryTicker is started more than once: +// - ErrStartedTimer is returned if the timer was already started. +// - ErrStoppedTimer is returned if the timer was stopped. +func (t *retryTickerImpl) Start() error { + if t.stopFlag.Get() { + return ErrStoppedTimer + } + if t.getTickTimer() != nil { + return ErrStartedTimer + } t.backoff = t.initialBackoff // initialize backoff strategy t.scheduleTick(0) + return nil } // Stop cancels this RetryTicker. 
If Stop is called while the tick function is running then Stop does not // wait for the tick function to complete before returning. func (t *retryTickerImpl) Stop() { + t.stopFlag.Set(true) t.setTickTimer(nil) } @@ -65,7 +84,7 @@ func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { defer cancel() nextTimeToTick, tickErr := t.doFunc(ctx) - if t.getTickTimer() == nil { + if t.stopFlag.Get() { // ticker was stopped while tick function was running. return } diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index e7852d0fbf1c9..a178b1e4c8856 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -62,6 +62,7 @@ func TestRetryTickerCallsTickFunction(t *testing.T) { mockFunc := &testTickFunc{} schedulerSpy := &afterFuncSpy{} ticker := newRetryTicker(t, mockFunc.doTick) + defer ticker.Stop() ticker.scheduler = schedulerSpy.afterFunc mockFunc.On("doTick", mock.Anything).Return(tc.timeToSecondTick, tc.firstErr).Once() @@ -76,11 +77,10 @@ func TestRetryTickerCallsTickFunction(t *testing.T) { schedulerSpy.On("afterFunc", backoff.Duration, mock.Anything).Return(nil).Once() } - ticker.Start() - defer ticker.Stop() + require.NoError(t, ticker.Start()) _, ok := doneErrSig.WaitWithTimeout(testTimeout) - assert.True(t, ok, "timeout exceeded") + require.True(t, ok, "timeout exceeded") mockFunc.AssertExpectations(t) schedulerSpy.AssertExpectations(t) }) @@ -96,8 +96,9 @@ func TestRetryTickerStop(t *testing.T) { require.True(t, ok) return capTime, nil }) + defer ticker.Stop() - ticker.Start() + require.NoError(t, ticker.Start()) _, ok := firsTickErrSig.WaitWithTimeout(testTimeout) require.True(t, ok, "timeout exceeded") ticker.Stop() @@ -108,6 +109,27 @@ func TestRetryTickerStop(t *testing.T) { assert.Nil(t, ticker.getTickTimer()) } +func TestRetryTickerStartWhileStarterFailure(t *testing.T) { + ticker := newRetryTicker(t, func(ctx context.Context) (timeToNextTick time.Duration, err error) { + 
return 0, nil + }) + defer ticker.Stop() + + require.NoError(t, ticker.Start()) + assert.ErrorIs(t, ErrStartedTimer, ticker.Start()) +} + +func TestRetryTickerStartTwiceFailure(t *testing.T) { + ticker := newRetryTicker(t, func(ctx context.Context) (timeToNextTick time.Duration, err error) { + return 0, nil + }) + defer ticker.Stop() + + require.NoError(t, ticker.Start()) + ticker.Stop() + require.ErrorIs(t, ErrStoppedTimer, ticker.Start()) +} + func newRetryTicker(t *testing.T, doFunc tickFunc) *retryTickerImpl { ticker := NewRetryTicker(doFunc, longTime, backoff) require.IsType(t, &retryTickerImpl{}, ticker) From b6f69a565a6a47a00c1ae1f3329c7ad9fa9abe56 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Thu, 3 Feb 2022 15:50:54 +0100 Subject: [PATCH 33/34] simplify ticker interface by removing Start method NewRetryTicker already returns a started ticker. That ensures tickers can only be started once --- pkg/concurrency/retry_ticker.go | 40 ++++++++++------------------ pkg/concurrency/retry_ticker_test.go | 35 +++++------------------- 2 files changed, 21 insertions(+), 54 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index 3acf4621d388b..f10880468d95d 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -4,44 +4,44 @@ import ( "context" "time" - "github.com/pkg/errors" "github.com/stackrox/rox/pkg/sync" "k8s.io/apimachinery/pkg/util/wait" ) var ( _ RetryTicker = (*retryTickerImpl)(nil) - // ErrStartedTimer is returned when Start is called on a timer that was already started. - ErrStartedTimer = errors.New("started timer") - // ErrStoppedTimer is returned when Start is called on a timer that was stopped. - ErrStoppedTimer = errors.New("stopped timer") ) // RetryTicker repeatedly calls a function with a timeout and a retry backoff strategy. -// RetryTickers can only be started once. -// RetryTickers are not safe for simultaneous use by multiple goroutines. 
type RetryTicker interface { - Start() error Stop() } type tickFunc func(ctx context.Context) (timeToNextTick time.Duration, err error) // NewRetryTicker returns a new RetryTicker that calls the "tick function" `doFunc` repeatedly: -// - When started, the RetryTicker calls `doFunc` immediately, and if that returns an error +// - The RetryTicker calls `doFunc` immediately, and if that returns an error // then the RetryTicker will wait the time returned by `backoff.Step` before calling `doFunc` again. // - `doFunc` should return an error if ctx is cancelled. RetryTicker always calls `doFunc` with a context // with a timeout of `timeout`. // - On success `RetryTicker` will reset `backoff`, and wait the amount of time returned by `doFunc` before // running it again. func NewRetryTicker(doFunc tickFunc, timeout time.Duration, backoff wait.Backoff) RetryTicker { - return &retryTickerImpl{ + return newRetryTicker(doFunc, timeout, backoff, true) +} + +func newRetryTicker(doFunc tickFunc, timeout time.Duration, backoff wait.Backoff, start bool) RetryTicker { + ticker := &retryTickerImpl{ scheduler: time.AfterFunc, doFunc: doFunc, timeout: timeout, initialBackoff: backoff, backoff: backoff, } + if start { + ticker.start() + } + return ticker } type retryTickerImpl struct { @@ -52,29 +52,16 @@ type retryTickerImpl struct { backoff wait.Backoff timer *time.Timer mutex sync.RWMutex - stopFlag Flag } // Start calls the tick function and schedules the next tick immediately. -// Start returns and error if the RetryTicker is started more than once: -// - ErrStartedTimer is returned if the timer was already started. -// - ErrStoppedTimer is returned if the timer was stopped. 
-func (t *retryTickerImpl) Start() error { - if t.stopFlag.Get() { - return ErrStoppedTimer - } - if t.getTickTimer() != nil { - return ErrStartedTimer - } - t.backoff = t.initialBackoff // initialize backoff strategy +func (t *retryTickerImpl) start() { t.scheduleTick(0) - return nil } // Stop cancels this RetryTicker. If Stop is called while the tick function is running then Stop does not // wait for the tick function to complete before returning. func (t *retryTickerImpl) Stop() { - t.stopFlag.Set(true) t.setTickTimer(nil) } @@ -84,7 +71,7 @@ func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { defer cancel() nextTimeToTick, tickErr := t.doFunc(ctx) - if t.stopFlag.Get() { + if t.getTickTimer() == nil { // ticker was stopped while tick function was running. return } @@ -92,7 +79,8 @@ func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { t.scheduleTick(t.backoff.Step()) return } - t.backoff = t.initialBackoff // reset backoff strategy + // reset backoff strategy. 
+ t.backoff = t.initialBackoff t.scheduleTick(nextTimeToTick) })) } diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index a178b1e4c8856..beae3ce6b78c2 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -61,7 +61,7 @@ func TestRetryTickerCallsTickFunction(t *testing.T) { doneErrSig := NewErrorSignal() mockFunc := &testTickFunc{} schedulerSpy := &afterFuncSpy{} - ticker := newRetryTicker(t, mockFunc.doTick) + ticker := newTestRetryTicker(t, mockFunc.doTick) defer ticker.Stop() ticker.scheduler = schedulerSpy.afterFunc @@ -77,7 +77,7 @@ func TestRetryTickerCallsTickFunction(t *testing.T) { schedulerSpy.On("afterFunc", backoff.Duration, mock.Anything).Return(nil).Once() } - require.NoError(t, ticker.Start()) + ticker.start() _, ok := doneErrSig.WaitWithTimeout(testTimeout) require.True(t, ok, "timeout exceeded") @@ -90,7 +90,7 @@ func TestRetryTickerCallsTickFunction(t *testing.T) { func TestRetryTickerStop(t *testing.T) { firsTickErrSig := NewErrorSignal() stopErrSig := NewErrorSignal() - ticker := newRetryTicker(t, func(ctx context.Context) (timeToNextTick time.Duration, err error) { + ticker := newTestRetryTicker(t, func(ctx context.Context) (timeToNextTick time.Duration, err error) { firsTickErrSig.Signal() _, ok := stopErrSig.WaitWithTimeout(testTimeout) require.True(t, ok) @@ -98,40 +98,19 @@ func TestRetryTickerStop(t *testing.T) { }) defer ticker.Stop() - require.NoError(t, ticker.Start()) + ticker.start() _, ok := firsTickErrSig.WaitWithTimeout(testTimeout) require.True(t, ok, "timeout exceeded") ticker.Stop() stopErrSig.Signal() - // ensure `ticker.scheduleTick` does not schedule a new timer after stopping the ticker + // ensure `ticker.scheduleTick` does not schedule a new timer after stopping the ticker. 
time.Sleep(capTime) assert.Nil(t, ticker.getTickTimer()) } -func TestRetryTickerStartWhileStarterFailure(t *testing.T) { - ticker := newRetryTicker(t, func(ctx context.Context) (timeToNextTick time.Duration, err error) { - return 0, nil - }) - defer ticker.Stop() - - require.NoError(t, ticker.Start()) - assert.ErrorIs(t, ErrStartedTimer, ticker.Start()) -} - -func TestRetryTickerStartTwiceFailure(t *testing.T) { - ticker := newRetryTicker(t, func(ctx context.Context) (timeToNextTick time.Duration, err error) { - return 0, nil - }) - defer ticker.Stop() - - require.NoError(t, ticker.Start()) - ticker.Stop() - require.ErrorIs(t, ErrStoppedTimer, ticker.Start()) -} - -func newRetryTicker(t *testing.T, doFunc tickFunc) *retryTickerImpl { - ticker := NewRetryTicker(doFunc, longTime, backoff) +func newTestRetryTicker(t *testing.T, doFunc tickFunc) *retryTickerImpl { + ticker := newRetryTicker(doFunc, longTime, backoff, false) require.IsType(t, &retryTickerImpl{}, ticker) return ticker.(*retryTickerImpl) } From d9fa4380a8d624e4dd583c8f8303a0fd7478d465 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Thu, 3 Feb 2022 15:58:53 +0100 Subject: [PATCH 34/34] Revert "simplify ticker interface by removing Start method" This reverts commit b6f69a565a6a47a00c1ae1f3329c7ad9fa9abe56. --- pkg/concurrency/retry_ticker.go | 40 ++++++++++++++++++---------- pkg/concurrency/retry_ticker_test.go | 35 +++++++++++++++++++----- 2 files changed, 54 insertions(+), 21 deletions(-) diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go index f10880468d95d..3acf4621d388b 100644 --- a/pkg/concurrency/retry_ticker.go +++ b/pkg/concurrency/retry_ticker.go @@ -4,44 +4,44 @@ import ( "context" "time" + "github.com/pkg/errors" "github.com/stackrox/rox/pkg/sync" "k8s.io/apimachinery/pkg/util/wait" ) var ( _ RetryTicker = (*retryTickerImpl)(nil) + // ErrStartedTimer is returned when Start is called on a timer that was already started. 
+ ErrStartedTimer = errors.New("started timer") + // ErrStoppedTimer is returned when Start is called on a timer that was stopped. + ErrStoppedTimer = errors.New("stopped timer") ) // RetryTicker repeatedly calls a function with a timeout and a retry backoff strategy. +// RetryTickers can only be started once. +// RetryTickers are not safe for simultaneous use by multiple goroutines. type RetryTicker interface { + Start() error Stop() } type tickFunc func(ctx context.Context) (timeToNextTick time.Duration, err error) // NewRetryTicker returns a new RetryTicker that calls the "tick function" `doFunc` repeatedly: -// - The RetryTicker calls `doFunc` immediately, and if that returns an error +// - When started, the RetryTicker calls `doFunc` immediately, and if that returns an error // then the RetryTicker will wait the time returned by `backoff.Step` before calling `doFunc` again. // - `doFunc` should return an error if ctx is cancelled. RetryTicker always calls `doFunc` with a context // with a timeout of `timeout`. // - On success `RetryTicker` will reset `backoff`, and wait the amount of time returned by `doFunc` before // running it again. func NewRetryTicker(doFunc tickFunc, timeout time.Duration, backoff wait.Backoff) RetryTicker { - return newRetryTicker(doFunc, timeout, backoff, true) -} - -func newRetryTicker(doFunc tickFunc, timeout time.Duration, backoff wait.Backoff, start bool) RetryTicker { - ticker := &retryTickerImpl{ + return &retryTickerImpl{ scheduler: time.AfterFunc, doFunc: doFunc, timeout: timeout, initialBackoff: backoff, backoff: backoff, } - if start { - ticker.start() - } - return ticker } type retryTickerImpl struct { @@ -52,16 +52,29 @@ type retryTickerImpl struct { backoff wait.Backoff timer *time.Timer mutex sync.RWMutex + stopFlag Flag } // Start calls the tick function and schedules the next tick immediately. 
-func (t *retryTickerImpl) start() { +// Start returns an error if the RetryTicker is started more than once: +// - ErrStartedTimer is returned if the timer was already started. +// - ErrStoppedTimer is returned if the timer was stopped. +func (t *retryTickerImpl) Start() error { + if t.stopFlag.Get() { + return ErrStoppedTimer + } + if t.getTickTimer() != nil { + return ErrStartedTimer + } + t.backoff = t.initialBackoff // initialize backoff strategy t.scheduleTick(0) + return nil } // Stop cancels this RetryTicker. If Stop is called while the tick function is running then Stop does not // wait for the tick function to complete before returning. func (t *retryTickerImpl) Stop() { + t.stopFlag.Set(true) t.setTickTimer(nil) } @@ -71,7 +84,7 @@ func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { defer cancel() nextTimeToTick, tickErr := t.doFunc(ctx) - if t.getTickTimer() == nil { + if t.stopFlag.Get() { // ticker was stopped while tick function was running. return } @@ -79,8 +92,7 @@ func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) { t.scheduleTick(t.backoff.Step()) return } - // reset backoff strategy. 
- t.backoff = t.initialBackoff + t.backoff = t.initialBackoff // reset backoff strategy t.scheduleTick(nextTimeToTick) })) } diff --git a/pkg/concurrency/retry_ticker_test.go b/pkg/concurrency/retry_ticker_test.go index beae3ce6b78c2..a178b1e4c8856 100644 --- a/pkg/concurrency/retry_ticker_test.go +++ b/pkg/concurrency/retry_ticker_test.go @@ -61,7 +61,7 @@ func TestRetryTickerCallsTickFunction(t *testing.T) { doneErrSig := NewErrorSignal() mockFunc := &testTickFunc{} schedulerSpy := &afterFuncSpy{} - ticker := newTestRetryTicker(t, mockFunc.doTick) + ticker := newRetryTicker(t, mockFunc.doTick) defer ticker.Stop() ticker.scheduler = schedulerSpy.afterFunc @@ -77,7 +77,7 @@ func TestRetryTickerCallsTickFunction(t *testing.T) { schedulerSpy.On("afterFunc", backoff.Duration, mock.Anything).Return(nil).Once() } - ticker.start() + require.NoError(t, ticker.Start()) _, ok := doneErrSig.WaitWithTimeout(testTimeout) require.True(t, ok, "timeout exceeded") @@ -90,7 +90,7 @@ func TestRetryTickerCallsTickFunction(t *testing.T) { func TestRetryTickerStop(t *testing.T) { firsTickErrSig := NewErrorSignal() stopErrSig := NewErrorSignal() - ticker := newTestRetryTicker(t, func(ctx context.Context) (timeToNextTick time.Duration, err error) { + ticker := newRetryTicker(t, func(ctx context.Context) (timeToNextTick time.Duration, err error) { firsTickErrSig.Signal() _, ok := stopErrSig.WaitWithTimeout(testTimeout) require.True(t, ok) @@ -98,19 +98,40 @@ func TestRetryTickerStop(t *testing.T) { }) defer ticker.Stop() - ticker.start() + require.NoError(t, ticker.Start()) _, ok := firsTickErrSig.WaitWithTimeout(testTimeout) require.True(t, ok, "timeout exceeded") ticker.Stop() stopErrSig.Signal() - // ensure `ticker.scheduleTick` does not schedule a new timer after stopping the ticker. 
+ // ensure `ticker.scheduleTick` does not schedule a new timer after stopping the ticker time.Sleep(capTime) assert.Nil(t, ticker.getTickTimer()) } -func newTestRetryTicker(t *testing.T, doFunc tickFunc) *retryTickerImpl { - ticker := newRetryTicker(doFunc, longTime, backoff, false) +func TestRetryTickerStartWhileStarterFailure(t *testing.T) { + ticker := newRetryTicker(t, func(ctx context.Context) (timeToNextTick time.Duration, err error) { + return 0, nil + }) + defer ticker.Stop() + + require.NoError(t, ticker.Start()) + assert.ErrorIs(t, ErrStartedTimer, ticker.Start()) +} + +func TestRetryTickerStartTwiceFailure(t *testing.T) { + ticker := newRetryTicker(t, func(ctx context.Context) (timeToNextTick time.Duration, err error) { + return 0, nil + }) + defer ticker.Stop() + + require.NoError(t, ticker.Start()) + ticker.Stop() + require.ErrorIs(t, ErrStoppedTimer, ticker.Start()) +} + +func newRetryTicker(t *testing.T, doFunc tickFunc) *retryTickerImpl { + ticker := NewRetryTicker(doFunc, longTime, backoff) require.IsType(t, &retryTickerImpl{}, ticker) return ticker.(*retryTickerImpl) }