diff --git a/image/templates/helm/stackrox-secured-cluster/templates/sensor.yaml.htpl b/image/templates/helm/stackrox-secured-cluster/templates/sensor.yaml.htpl
index 5cc349847c53e..77b4003811b5b 100644
--- a/image/templates/helm/stackrox-secured-cluster/templates/sensor.yaml.htpl
+++ b/image/templates/helm/stackrox-secured-cluster/templates/sensor.yaml.htpl
@@ -108,6 +108,10 @@ spec:
           valueFrom:
             fieldRef:
               fieldPath: metadata.namespace
+        - name: POD_NAME
+          valueFrom:
+            fieldRef:
+              fieldPath: metadata.name
         - name: ROX_CENTRAL_ENDPOINT
           value: {{ ._rox.centralEndpoint }}
         - name: ROX_ADVERTISED_ENDPOINT
diff --git a/pkg/centralsensor/caps_list.go b/pkg/centralsensor/caps_list.go
index e31cd51e8a556..e0b85406eee2b 100644
--- a/pkg/centralsensor/caps_list.go
+++ b/pkg/centralsensor/caps_list.go
@@ -24,4 +24,7 @@ const (
 
     // AuditLogEventsCap identifies the capability to handle audit log event detection.
     AuditLogEventsCap SensorCapability = "AuditLogEvents"
+
+    // LocalScannerCredentialsRefresh identifies the capability to keep the local scanner TLS credentials refreshed.
+    LocalScannerCredentialsRefresh SensorCapability = "LocalScannerCredentialsRefresh"
 )
diff --git a/pkg/concurrency/retry_ticker.go b/pkg/concurrency/retry_ticker.go
index 58e767a071b67..1505a28548909 100644
--- a/pkg/concurrency/retry_ticker.go
+++ b/pkg/concurrency/retry_ticker.go
@@ -25,6 +25,7 @@ var (
 type RetryTicker interface {
     Start() error
     Stop()
+    Stopped() bool
 }
 
 type tickFunc func(ctx context.Context) (timeToNextTick time.Duration, err error)
@@ -64,7 +65,7 @@ type retryTickerImpl struct {
 // - ErrStartedTimer is returned if the timer was already started.
 // - ErrStoppedTimer is returned if the timer was stopped.
 func (t *retryTickerImpl) Start() error {
-    if t.stopFlag.Get() {
+    if t.Stopped() {
         return ErrStoppedTimer
     }
     if t.getTickTimer() != nil {
@@ -82,13 +83,18 @@ func (t *retryTickerImpl) Stop() {
     t.setTickTimer(nil)
 }
 
+// Stopped returns true if this RetryTicker has been stopped, otherwise returns false.
+func (t *retryTickerImpl) Stopped() bool {
+    return t.stopFlag.Get()
+}
+
 func (t *retryTickerImpl) scheduleTick(timeToTick time.Duration) {
     t.setTickTimer(t.scheduler(timeToTick, func() {
         ctx, cancel := context.WithTimeout(context.Background(), t.timeout)
         defer cancel()
 
         nextTimeToTick, tickErr := t.doFunc(ctx)
-        if t.stopFlag.Get() {
+        if t.Stopped() {
             // ticker was stopped while tick function was running.
             return
         }
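The new Stopped() accessor exposes the same stopFlag state that Start() and scheduleTick() already consult internally, so callers no longer need access to the ticker's internals. As an illustration (not part of this change, and the helper name is hypothetical), a caller holding a concurrency.RetryTicker could use it to avoid a guaranteed ErrStoppedTimer from Start():

    package example // illustrative sketch only

    import "github.com/stackrox/rox/pkg/concurrency"

    // restartIfRunning uses only the RetryTicker interface shown above:
    // a stopped ticker cannot be restarted, so calling Start on it would
    // just return ErrStoppedTimer.
    func restartIfRunning(t concurrency.RetryTicker) error {
        if t.Stopped() {
            return concurrency.ErrStoppedTimer
        }
        return t.Start()
    }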
diff --git a/sensor/kubernetes/localscanner/cert_refresher.go b/sensor/kubernetes/localscanner/cert_refresher.go
index 605aa9395ecae..bf84bfe8866a9 100644
--- a/sensor/kubernetes/localscanner/cert_refresher.go
+++ b/sensor/kubernetes/localscanner/cert_refresher.go
@@ -47,7 +47,7 @@ func refreshCertificates(ctx context.Context, requestCertificates requestCertifi
     timeToNextRefresh, err = ensureCertificatesAreFresh(ctx, requestCertificates, getCertsRenewalTime, repository)
     if err != nil {
         if errors.Is(err, ErrUnexpectedSecretsOwner) {
-            log.Errorf("stopping automatic refresh of %s: %s", certsDescription, err)
+            log.Errorf("non-recoverable error refreshing %s, automatic refresh will be stopped: %s", certsDescription, err)
             return 0, concurrency.ErrNonRecoverable
         }
 
@@ -97,7 +97,10 @@ func getTimeToRefreshFromRepo(ctx context.Context, getCertsRenewalTime getCertsR
     repository serviceCertificatesRepo) (time.Duration, error) {
 
     certificates, getCertsErr := repository.getServiceCertificates(ctx)
-    if getCertsErr == ErrDifferentCAForDifferentServiceTypes || getCertsErr == ErrMissingSecretData {
+    if errors.Is(getCertsErr, ErrUnexpectedSecretsOwner) {
+        return 0, getCertsErr
+    }
+    if errors.Is(getCertsErr, ErrDifferentCAForDifferentServiceTypes) || errors.Is(getCertsErr, ErrMissingSecretData) {
         log.Errorf("local scanner certificates are in an inconsistent state, "+
             "will refresh certificates immediately: %s", getCertsErr)
         return 0, nil
diff --git a/sensor/kubernetes/localscanner/cert_refresher_test.go b/sensor/kubernetes/localscanner/cert_refresher_test.go
index 28b77a576be33..d33e8dad707dc 100644
--- a/sensor/kubernetes/localscanner/cert_refresher_test.go
+++ b/sensor/kubernetes/localscanner/cert_refresher_test.go
@@ -5,6 +5,7 @@ import (
     "testing"
     "time"
 
+    "github.com/hashicorp/go-multierror"
     "github.com/pkg/errors"
     "github.com/stackrox/rox/generated/internalapi/central"
     "github.com/stackrox/rox/generated/storage"
@@ -110,8 +111,8 @@ func (s *certRefresherSuite) TestRefreshCertificatesGetCertsInconsistentImmediat
     testCases := map[string]struct {
         recoverableErr error
     }{
-        "refresh immediately on ErrDifferentCAForDifferentServiceTypes": {recoverableErr: ErrDifferentCAForDifferentServiceTypes},
-        "refresh immediately on ErrMissingSecretData":                   {recoverableErr: ErrMissingSecretData},
+        "refresh immediately on ErrDifferentCAForDifferentServiceTypes": {recoverableErr: errors.Wrap(ErrDifferentCAForDifferentServiceTypes, "wrap error")},
+        "refresh immediately on ErrMissingSecretData":                   {recoverableErr: errors.Wrap(ErrMissingSecretData, "wrap error")},
         "refresh immediately on missing secrets":                        {recoverableErr: k8sErrors.NewNotFound(schema.GroupResource{Group: "Core", Resource: "Secret"}, "foo")},
     }
     for tcName, tc := range testCases {
@@ -132,9 +133,10 @@ func (s *certRefresherSuite) TestRefreshCertificatesGetCertsInconsistentImmediat
     }
 }
 
-func (s *certRefresherSuite) TestRefreshCertificatesGetCertsUnexpectedOwnerFailure() {
+func (s *certRefresherSuite) TestRefreshCertificatesGetCertsUnexpectedOwnerHighestPriorityFailure() {
+    getErr := multierror.Append(nil, ErrUnexpectedSecretsOwner, ErrDifferentCAForDifferentServiceTypes, ErrMissingSecretData)
     s.dependenciesMock.On("getServiceCertificates", mock.Anything).Once().Return(
-        (*storage.TypedServiceCertificateSet)(nil), concurrency.ErrNonRecoverable)
+        (*storage.TypedServiceCertificateSet)(nil), getErr)
 
     _, err := s.refreshCertificates()
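The switch from == comparisons to errors.Is above is what makes the wrapped-error test cases pass, and multierror.Append builds an aggregate that errors.Is can also see through (assuming a go-multierror version that implements Unwrap, as the new test dependency does), which is why the highest-priority-failure test can bundle all three sentinels into one error. A self-contained sketch with stand-in sentinels:

    package main

    import (
        "errors"
        "fmt"

        "github.com/hashicorp/go-multierror"
    )

    // Stand-ins for sentinels like ErrUnexpectedSecretsOwner and ErrMissingSecretData.
    var (
        errOwner   = errors.New("unexpected secrets owner")
        errMissing = errors.New("missing secret data")
    )

    func main() {
        // errors.Is matches through %w-style wrapping...
        wrapped := fmt.Errorf("wrap error: %w", errMissing)
        fmt.Println(errors.Is(wrapped, errMissing)) // true

        // ...and through a go-multierror aggregate, so the caller can check
        // the most severe sentinel first, as getTimeToRefreshFromRepo now does.
        agg := multierror.Append(nil, errOwner, errMissing)
        fmt.Println(errors.Is(agg, errOwner))   // true
        fmt.Println(errors.Is(agg, errMissing)) // true
    }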
diff --git a/sensor/kubernetes/localscanner/certificate_expiration_test.go b/sensor/kubernetes/localscanner/certificate_expiration_test.go
index 3d0a9365ca795..5e42021302c60 100644
--- a/sensor/kubernetes/localscanner/certificate_expiration_test.go
+++ b/sensor/kubernetes/localscanner/certificate_expiration_test.go
@@ -68,15 +68,23 @@ func (s *getSecretRenewalTimeSuite) TestGetSecretsCertRenewalTime() {
     s.LessOrEqual(certDuration, afterOffset/2)
 }
 
-func issueCertificatePEM(issueOption mtls.IssueCertOption) ([]byte, error) {
+func issueCertificate(serviceType storage.ServiceType, issueOption mtls.IssueCertOption) (*mtls.IssuedCert, error) {
     ca, err := mtls.CAForSigning()
     if err != nil {
         return nil, err
     }
-    subject := mtls.NewSubject("clusterId", storage.ServiceType_SCANNER_SERVICE)
+    subject := mtls.NewSubject("clusterId", serviceType)
     cert, err := ca.IssueCertForSubject(subject, issueOption)
     if err != nil {
         return nil, err
     }
-    return cert.CertPEM, err
+    return cert, err
+}
+
+func issueCertificatePEM(issueOption mtls.IssueCertOption) ([]byte, error) {
+    cert, err := issueCertificate(storage.ServiceType_SCANNER_SERVICE, issueOption)
+    if err != nil {
+        return nil, err
+    }
+    return cert.CertPEM, nil
 }
diff --git a/sensor/kubernetes/localscanner/certificate_requester.go b/sensor/kubernetes/localscanner/certificate_requester.go
index 4de43ad69a69e..b2cb4252f19bb 100644
--- a/sensor/kubernetes/localscanner/certificate_requester.go
+++ b/sensor/kubernetes/localscanner/certificate_requester.go
@@ -6,7 +6,6 @@ import (
     "github.com/pkg/errors"
     "github.com/stackrox/rox/generated/internalapi/central"
     "github.com/stackrox/rox/pkg/concurrency"
-    "github.com/stackrox/rox/pkg/logging"
     "github.com/stackrox/rox/pkg/sync"
     "github.com/stackrox/rox/pkg/uuid"
 )
@@ -15,17 +14,9 @@ var (
     // ErrCertificateRequesterStopped is returned by RequestCertificates when the certificate
     // requester is not initialized.
     ErrCertificateRequesterStopped = errors.New("stopped")
-    log = logging.LoggerForModule()
     _ CertificateRequester = (*certificateRequesterImpl)(nil)
 )
 
-// CertificateRequester requests a new set of local scanner certificates from central.
-type CertificateRequester interface {
-    Start()
-    Stop()
-    RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error)
-}
-
 // NewCertificateRequester creates a new certificate requester that communicates through
 // the specified channels and initializes a new request ID for each request.
 // To use it, call Start and then make requests with RequestCertificates; concurrent requests are supported.
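With the interface moved out of certificate_requester.go (it reappears in tls_issuer.go below), this file keeps only the implementation. Going by the doc comment above, usage follows a start/request/stop lifecycle; a hedged sketch of a caller, with the helper function and timeout being illustrative rather than taken from this change:

    package example // illustrative sketch only

    import (
        "context"
        "time"

        "github.com/stackrox/rox/generated/internalapi/central"
        "github.com/stackrox/rox/sensor/kubernetes/localscanner"
    )

    // requestOnce follows the documented lifecycle: Start, then
    // RequestCertificates (concurrent requests are supported), then Stop.
    func requestOnce(requester localscanner.CertificateRequester) (*central.IssueLocalScannerCertsResponse, error) {
        requester.Start()
        defer requester.Stop()

        ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
        defer cancel()
        return requester.RequestCertificates(ctx)
    }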
diff --git a/sensor/kubernetes/localscanner/service_certificates_repository.go b/sensor/kubernetes/localscanner/service_certificates_repository.go
index c369a3087594c..b369236ba532d 100644
--- a/sensor/kubernetes/localscanner/service_certificates_repository.go
+++ b/sensor/kubernetes/localscanner/service_certificates_repository.go
@@ -53,7 +53,7 @@ type serviceCertSecretSpec struct {
 // newServiceCertificatesRepo creates a new serviceCertificatesRepoSecretsImpl that persists certificates for
 // scanner and scanner DB in k8s secrets that are expected to have ownerReference as the only owner reference.
 func newServiceCertificatesRepo(ownerReference metav1.OwnerReference, namespace string,
-    secretsClient corev1.SecretInterface) *serviceCertificatesRepoSecretsImpl {
+    secretsClient corev1.SecretInterface) serviceCertificatesRepo {
 
     return &serviceCertificatesRepoSecretsImpl{
         secrets: map[storage.ServiceType]serviceCertSecretSpec{
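Returning the serviceCertificatesRepo interface rather than the concrete *serviceCertificatesRepoSecretsImpl lets newServiceCertificatesRepo satisfy the serviceCertificatesRepoGetter function type declared in tls_issuer.go below, and lets the tests substitute certsRepoMock. The pattern in miniature, with all names here being illustrative:

    package example // illustrative sketch only

    import "context"

    // repo mirrors serviceCertificatesRepo: consumers depend on the interface.
    type repo interface {
        get(ctx context.Context) (string, error)
    }

    // secretsRepo mirrors the k8s-secrets-backed implementation.
    type secretsRepo struct{}

    func (secretsRepo) get(ctx context.Context) (string, error) { return "certs", nil }

    // newRepo returns the interface type, as newServiceCertificatesRepo now
    // does, so production and mock implementations are interchangeable.
    func newRepo() repo { return secretsRepo{} }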
diff --git a/sensor/kubernetes/localscanner/tls_issuer.go b/sensor/kubernetes/localscanner/tls_issuer.go
new file mode 100644
index 0000000000000..295bfc06bf501
--- /dev/null
+++ b/sensor/kubernetes/localscanner/tls_issuer.go
@@ -0,0 +1,247 @@
+package localscanner
+
+import (
+    "context"
+    "time"
+
+    "github.com/pkg/errors"
+    "github.com/stackrox/rox/generated/internalapi/central"
+    "github.com/stackrox/rox/pkg/centralsensor"
+    "github.com/stackrox/rox/pkg/concurrency"
+    "github.com/stackrox/rox/pkg/logging"
+    "github.com/stackrox/rox/sensor/common"
+    k8sErrors "k8s.io/apimachinery/pkg/api/errors"
+    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+    "k8s.io/apimachinery/pkg/util/wait"
+    "k8s.io/client-go/kubernetes"
+    corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
+    "k8s.io/client-go/util/retry"
+)
+
+var (
+    log = logging.LoggerForModule()
+
+    startTimeout                          = 6 * time.Minute
+    fetchSensorDeploymentOwnerRefBackoff = wait.Backoff{
+        Duration: 10 * time.Millisecond,
+        Factor:   3,
+        Jitter:   0.1,
+        Steps:    10,
+        Cap:      startTimeout,
+    }
+    processMessageTimeout = 5 * time.Second
+    certRefreshTimeout    = 5 * time.Minute
+    certRefreshBackoff    = wait.Backoff{
+        Duration: 5 * time.Second,
+        Factor:   3.0,
+        Jitter:   0.1,
+        Steps:    5,
+        Cap:      10 * time.Minute,
+    }
+    _ common.SensorComponent = (*localScannerTLSIssuerImpl)(nil)
+)
+
+// NewLocalScannerTLSIssuer creates a sensor component that will keep the local scanner certificates
+// up to date, using the specified retry parameters.
+func NewLocalScannerTLSIssuer(
+    k8sClient kubernetes.Interface,
+    sensorNamespace string,
+    sensorPodName string,
+) common.SensorComponent {
+    msgToCentralC := make(chan *central.MsgFromSensor)
+    msgFromCentralC := make(chan *central.IssueLocalScannerCertsResponse)
+    return &localScannerTLSIssuerImpl{
+        sensorNamespace:              sensorNamespace,
+        sensorPodName:                sensorPodName,
+        k8sClient:                    k8sClient,
+        msgToCentralC:                msgToCentralC,
+        msgFromCentralC:              msgFromCentralC,
+        certRefreshBackoff:           certRefreshBackoff,
+        getCertificateRefresherFn:    newCertificatesRefresher,
+        getServiceCertificatesRepoFn: newServiceCertificatesRepo,
+        certRequester:                NewCertificateRequester(msgToCentralC, msgFromCentralC),
+    }
+}
+
+type localScannerTLSIssuerImpl struct {
+    sensorNamespace              string
+    sensorPodName                string
+    k8sClient                    kubernetes.Interface
+    msgToCentralC                chan *central.MsgFromSensor
+    msgFromCentralC              chan *central.IssueLocalScannerCertsResponse
+    certRefreshBackoff           wait.Backoff
+    getCertificateRefresherFn    certificateRefresherGetter
+    getServiceCertificatesRepoFn serviceCertificatesRepoGetter
+    certRequester                CertificateRequester
+    certRefresher                concurrency.RetryTicker
+}
+
+// CertificateRequester requests a new set of local scanner certificates from central.
+type CertificateRequester interface {
+    Start()
+    Stop()
+    RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error)
+}
+
+type certificateRefresherGetter func(requestCertificates requestCertificatesFunc, repository serviceCertificatesRepo,
+    timeout time.Duration, backoff wait.Backoff) concurrency.RetryTicker
+
+type serviceCertificatesRepoGetter func(ownerReference metav1.OwnerReference, namespace string,
+    secretsClient corev1.SecretInterface) serviceCertificatesRepo
+
+// Start starts the sensor component and launches a certificate refresher that immediately checks the certificates,
+// and then keeps them updated.
+// In case a secret doesn't have the expected owner, this logs a warning and returns nil.
+// In case this component was already started, it fails immediately.
+func (i *localScannerTLSIssuerImpl) Start() error {
+    log.Debug("starting local scanner TLS issuer.")
+    ctx, cancel := context.WithTimeout(context.Background(), startTimeout)
+    defer cancel()
+
+    if i.certRefresher != nil {
+        return i.abortStart(errors.New("already started."))
+    }
+
+    sensorOwnerReference, fetchSensorDeploymentErr := i.fetchSensorDeploymentOwnerRef(ctx, fetchSensorDeploymentOwnerRefBackoff)
+    if fetchSensorDeploymentErr != nil {
+        return i.abortStart(errors.Wrap(fetchSensorDeploymentErr, "fetching sensor deployment"))
+    }
+
+    certsRepo := i.getServiceCertificatesRepoFn(*sensorOwnerReference, i.sensorNamespace,
+        i.k8sClient.CoreV1().Secrets(i.sensorNamespace))
+    i.certRefresher = i.getCertificateRefresherFn(i.certRequester.RequestCertificates, certsRepo,
+        certRefreshTimeout, i.certRefreshBackoff)
+
+    i.certRequester.Start()
+    if refreshStartErr := i.certRefresher.Start(); refreshStartErr != nil {
+        return i.abortStart(errors.Wrap(refreshStartErr, "starting certificate refresher"))
+    }
+
+    log.Debug("local scanner TLS issuer started.")
+    return nil
+}
+
+func (i *localScannerTLSIssuerImpl) abortStart(err error) error {
+    log.Errorf("local scanner TLS issuer start aborted due to error: %s", err)
+    i.Stop(err)
+    // This component should never stop Sensor.
+    return nil
+}
+
+func (i *localScannerTLSIssuerImpl) Stop(_ error) {
+    if i.certRefresher != nil {
+        i.certRefresher.Stop()
+        i.certRefresher = nil
+    }
+
+    i.certRequester.Stop()
+    log.Debug("local scanner TLS issuer stopped.")
+}
+
+func (i *localScannerTLSIssuerImpl) Capabilities() []centralsensor.SensorCapability {
+    return []centralsensor.SensorCapability{centralsensor.LocalScannerCredentialsRefresh}
+}
+
+// ResponsesC is called "responses" because for other SensorComponents it is central that
+// initiates the interaction. However, here it is sensor which sends a request to central.
+func (i *localScannerTLSIssuerImpl) ResponsesC() <-chan *central.MsgFromSensor {
+    return i.msgToCentralC
+}
+
+// ProcessMessage dispatches messages from Central that are received via the central receiver.
+// This method must not block as it would prevent centralReceiverImpl from sending messages
+// to other SensorComponents.
+func (i *localScannerTLSIssuerImpl) ProcessMessage(msg *central.MsgToSensor) error {
+    switch m := msg.GetMsg().(type) {
+    case *central.MsgToSensor_IssueLocalScannerCertsResponse:
+        response := m.IssueLocalScannerCertsResponse
+        go func() {
+            ctx, cancel := context.WithTimeout(context.Background(), processMessageTimeout)
+            defer cancel()
+            select {
+            case <-ctx.Done():
+                // certRefresher will retry.
+                log.Errorf("timeout forwarding response %s from central: %s", response, ctx.Err())
+            case i.msgFromCentralC <- response:
+            }
+        }()
+        return nil
+    default:
+        // Messages not supported by this component are ignored because unknown message types are handled by the central receiver.
+        return nil
+    }
+}
+
+func (i *localScannerTLSIssuerImpl) fetchSensorDeploymentOwnerRef(ctx context.Context, backoff wait.Backoff) (*metav1.OwnerReference, error) {
+    if i.sensorPodName == "" {
+        return nil, errors.New("fetching sensor deployment: empty pod name")
+    }
+
+    podsClient := i.k8sClient.CoreV1().Pods(i.sensorNamespace)
+    sensorPodMeta, getPodErr := i.getObjectMetaWithRetries(ctx, backoff, func(ctx context.Context) (metav1.Object, error) {
+        pod, err := podsClient.Get(ctx, i.sensorPodName, metav1.GetOptions{})
+        if err != nil {
+            return nil, err
+        }
+        return pod.GetObjectMeta(), nil
+    })
+    if getPodErr != nil {
+        return nil, errors.Wrapf(getPodErr, "fetching sensor pod with name %q", i.sensorPodName)
+    }
+    podOwners := sensorPodMeta.GetOwnerReferences()
+    if len(podOwners) != 1 {
+        return nil, errors.Errorf("pod %q has unexpected owners %v",
+            i.sensorPodName, podOwners)
+    }
+    podOwnerName := podOwners[0].Name
+
+    replicaSetClient := i.k8sClient.AppsV1().ReplicaSets(i.sensorNamespace)
+    ownerReplicaSetMeta, getReplicaSetErr := i.getObjectMetaWithRetries(ctx, backoff,
+        func(ctx context.Context) (metav1.Object, error) {
+            replicaSet, err := replicaSetClient.Get(ctx, podOwnerName, metav1.GetOptions{})
+            if err != nil {
+                return nil, err
+            }
+            return replicaSet.GetObjectMeta(), nil
+        })
+    if getReplicaSetErr != nil {
+        return nil, errors.Wrapf(getReplicaSetErr, "fetching owner replica set with name %q", podOwnerName)
+    }
+    replicaSetOwners := ownerReplicaSetMeta.GetOwnerReferences()
+    if len(replicaSetOwners) != 1 {
+        return nil, errors.Errorf("replica set %q has unexpected owners %v",
+            ownerReplicaSetMeta.GetName(),
+            replicaSetOwners)
+    }
+    replicaSetOwner := replicaSetOwners[0]
+
+    blockOwnerDeletion := false
+    isController := false
+    return &metav1.OwnerReference{
+        APIVersion:         replicaSetOwner.APIVersion,
+        Kind:               replicaSetOwner.Kind,
+        Name:               replicaSetOwner.Name,
+        UID:                replicaSetOwner.UID,
+        BlockOwnerDeletion: &blockOwnerDeletion,
+        Controller:         &isController,
+    }, nil
+}
+
+func (i *localScannerTLSIssuerImpl) getObjectMetaWithRetries(
+    ctx context.Context,
+    backoff wait.Backoff,
+    getObject func(context.Context) (metav1.Object, error),
+) (metav1.Object, error) {
+    var object metav1.Object
+    getErr := retry.OnError(backoff, func(err error) bool {
+        return !k8sErrors.IsNotFound(err)
+    }, func() error {
+        newObject, err := getObject(ctx)
+        if err == nil {
+            object = newObject
+        }
+        return err
+    })
+
+    return object, getErr
+}
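getObjectMetaWithRetries leans on client-go's retry.OnError primitive: retry under the given backoff for every error the predicate accepts, and give up immediately otherwise (here, on NotFound, since a missing pod or replica set is treated as fatal). A self-contained sketch of that primitive, with backoff values and errors that are illustrative rather than the production settings:

    package main

    import (
        "errors"
        "fmt"
        "time"

        "k8s.io/apimachinery/pkg/util/wait"
        "k8s.io/client-go/util/retry"
    )

    func main() {
        backoff := wait.Backoff{Duration: 10 * time.Millisecond, Factor: 3, Jitter: 0.1, Steps: 10}
        errTransient := errors.New("transient API error")
        attempts := 0

        err := retry.OnError(backoff,
            // The predicate decides which errors are retriable; the TLS issuer
            // retries everything except k8sErrors.IsNotFound(err).
            func(err error) bool { return errors.Is(err, errTransient) },
            func() error {
                attempts++
                if attempts < 3 {
                    return errTransient
                }
                return nil
            })
        fmt.Println(attempts, err) // 3 <nil>
    }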
+ log.Errorf("timeout forwarding response %s from central: %s", response, ctx.Err()) + case i.msgFromCentralC <- response: + } + }() + return nil + default: + // messages not supported by this component are ignored because unknown messages types are handled by the central receiver. + return nil + } +} + +func (i *localScannerTLSIssuerImpl) fetchSensorDeploymentOwnerRef(ctx context.Context, backoff wait.Backoff) (*metav1.OwnerReference, error) { + if i.sensorPodName == "" { + return nil, errors.New("fetching sensor deployment: empty pod name") + } + + podsClient := i.k8sClient.CoreV1().Pods(i.sensorNamespace) + sensorPodMeta, getPodErr := i.getObjectMetaWithRetries(ctx, backoff, func(ctx context.Context) (metav1.Object, error) { + pod, err := podsClient.Get(ctx, i.sensorPodName, metav1.GetOptions{}) + if err != nil { + return nil, err + } + return pod.GetObjectMeta(), nil + }) + if getPodErr != nil { + return nil, errors.Wrapf(getPodErr, "fetching sensor pod with name %q", i.sensorPodName) + } + podOwners := sensorPodMeta.GetOwnerReferences() + if len(podOwners) != 1 { + return nil, errors.Errorf("pod %q has unexpected owners %v", + i.sensorPodName, podOwners) + } + podOwnerName := podOwners[0].Name + + replicaSetClient := i.k8sClient.AppsV1().ReplicaSets(i.sensorNamespace) + ownerReplicaSetMeta, getReplicaSetErr := i.getObjectMetaWithRetries(ctx, backoff, + func(ctx context.Context) (metav1.Object, error) { + replicaSet, err := replicaSetClient.Get(ctx, podOwnerName, metav1.GetOptions{}) + if err != nil { + return nil, err + } + return replicaSet.GetObjectMeta(), nil + }) + if getReplicaSetErr != nil { + return nil, errors.Wrapf(getReplicaSetErr, "fetching owner replica set with name %q", podOwnerName) + } + replicaSetOwners := ownerReplicaSetMeta.GetOwnerReferences() + if len(replicaSetOwners) != 1 { + return nil, errors.Errorf("replica set %q has unexpected owners %v", + ownerReplicaSetMeta.GetName(), + replicaSetOwners) + } + replicaSetOwner := replicaSetOwners[0] + + blockOwnerDeletion := false + isController := false + return &metav1.OwnerReference{ + APIVersion: replicaSetOwner.APIVersion, + Kind: replicaSetOwner.Kind, + Name: replicaSetOwner.Name, + UID: replicaSetOwner.UID, + BlockOwnerDeletion: &blockOwnerDeletion, + Controller: &isController, + }, nil +} + +func (i *localScannerTLSIssuerImpl) getObjectMetaWithRetries( + ctx context.Context, + backoff wait.Backoff, + getObject func(context.Context) (metav1.Object, error), +) (metav1.Object, error) { + var object metav1.Object + getErr := retry.OnError(backoff, func(err error) bool { + return !k8sErrors.IsNotFound(err) + }, func() error { + newObject, err := getObject(ctx) + if err == nil { + object = newObject + } + return err + }) + + return object, getErr +} diff --git a/sensor/kubernetes/localscanner/tls_issuer_test.go b/sensor/kubernetes/localscanner/tls_issuer_test.go new file mode 100644 index 0000000000000..7dd018af8f236 --- /dev/null +++ b/sensor/kubernetes/localscanner/tls_issuer_test.go @@ -0,0 +1,577 @@ +package localscanner + +import ( + "context" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/stackrox/rox/generated/internalapi/central" + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/concurrency" + "github.com/stackrox/rox/pkg/mtls" + testutilsMTLS "github.com/stackrox/rox/pkg/mtls/testutils" + "github.com/stackrox/rox/pkg/testutils/envisolator" + "github.com/stackrox/rox/pkg/uuid" + "github.com/stackrox/rox/sensor/common" + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + appsApiv1 "k8s.io/api/apps/v1" + v1 "k8s.io/api/core/v1" + k8sErrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/fake" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" +) + +var ( + sensorNamespace = "stackrox-ns" + sensorReplicasetName = "sensor-replicaset" + sensorPodName = "sensor-pod" +) + +type localScannerTLSIssuerFixture struct { + k8sClient *fake.Clientset + certRequester *certificateRequesterMock + certRefresher *certificateRefresherMock + repo *certsRepoMock + componentGetter *componentGetterMock + tlsIssuer *localScannerTLSIssuerImpl +} + +func newLocalScannerTLSIssuerFixture(k8sClientConfig fakeK8sClientConfig) *localScannerTLSIssuerFixture { + fixture := &localScannerTLSIssuerFixture{ + certRequester: &certificateRequesterMock{}, + certRefresher: &certificateRefresherMock{}, + repo: &certsRepoMock{}, + componentGetter: &componentGetterMock{}, + k8sClient: getFakeK8sClient(k8sClientConfig), + } + msgToCentralC := make(chan *central.MsgFromSensor) + msgFromCentralC := make(chan *central.IssueLocalScannerCertsResponse) + fixture.tlsIssuer = &localScannerTLSIssuerImpl{ + sensorNamespace: sensorNamespace, + sensorPodName: sensorPodName, + k8sClient: fixture.k8sClient, + msgToCentralC: msgToCentralC, + msgFromCentralC: msgFromCentralC, + certRefreshBackoff: certRefreshBackoff, + getCertificateRefresherFn: fixture.componentGetter.getCertificateRefresher, + getServiceCertificatesRepoFn: fixture.componentGetter.getServiceCertificatesRepo, + certRequester: fixture.certRequester, + } + + return fixture +} + +func (f *localScannerTLSIssuerFixture) assertMockExpectations(t *testing.T) { + f.certRequester.AssertExpectations(t) + f.certRequester.AssertExpectations(t) + f.componentGetter.AssertExpectations(t) +} + +// mockForStart setups the mocks for the happy path of Start +func (f *localScannerTLSIssuerFixture) mockForStart(conf mockForStartConfig) { + f.certRequester.On("Start").Once() + f.certRefresher.On("Start").Once().Return(conf.refresherStartErr) + + f.repo.On("getServiceCertificates", mock.Anything).Once(). 
+func TestLocalScannerTLSIssuerStartStopSuccess(t *testing.T) {
+    testCases := map[string]struct {
+        getCertsErr error
+    }{
+        "no error":            {getCertsErr: nil},
+        "missing secret data": {getCertsErr: errors.Wrap(ErrMissingSecretData, "wrap error")},
+        "inconsistent CAs":    {getCertsErr: errors.Wrap(ErrDifferentCAForDifferentServiceTypes, "wrap error")},
+        "missing secret":      {getCertsErr: k8sErrors.NewNotFound(schema.GroupResource{Group: "Core", Resource: "Secret"}, "scanner-db-slim-tls")},
+    }
+    for tcName, tc := range testCases {
+        t.Run(tcName, func(t *testing.T) {
+            fixture := newLocalScannerTLSIssuerFixture(fakeK8sClientConfig{})
+            fixture.mockForStart(mockForStartConfig{getCertsErr: tc.getCertsErr})
+            fixture.certRefresher.On("Stop").Once()
+            fixture.certRequester.On("Stop").Once()
+
+            startErr := fixture.tlsIssuer.Start()
+            fixture.tlsIssuer.Stop(nil)
+
+            assert.NoError(t, startErr)
+            assert.Nil(t, fixture.tlsIssuer.certRefresher)
+            fixture.assertMockExpectations(t)
+        })
+    }
+}
+
+func TestLocalScannerTLSIssuerRefresherFailureStartFailure(t *testing.T) {
+    fixture := newLocalScannerTLSIssuerFixture(fakeK8sClientConfig{})
+    fixture.mockForStart(mockForStartConfig{refresherStartErr: errForced})
+    fixture.certRefresher.On("Stop").Once()
+    fixture.certRequester.On("Stop").Once()
+
+    startErr := fixture.tlsIssuer.Start()
+
+    assert.NoError(t, startErr)
+    fixture.assertMockExpectations(t)
+}
+
+func TestLocalScannerTLSIssuerStartAlreadyStartedFailure(t *testing.T) {
+    fixture := newLocalScannerTLSIssuerFixture(fakeK8sClientConfig{})
+    fixture.mockForStart(mockForStartConfig{})
+    fixture.certRefresher.On("Stop").Once()
+    fixture.certRequester.On("Stop").Once()
+
+    startErr := fixture.tlsIssuer.Start()
+    secondStartErr := fixture.tlsIssuer.Start()
+
+    assert.NoError(t, startErr)
+    assert.NoError(t, secondStartErr)
+    fixture.assertMockExpectations(t)
+}
+
+func TestLocalScannerTLSIssuerFetchSensorDeploymentOwnerRefErrorStartFailure(t *testing.T) {
+    testCases := map[string]struct {
+        k8sClientConfig fakeK8sClientConfig
+    }{
+        "sensor replica set missing": {k8sClientConfig: fakeK8sClientConfig{skipSensorReplicaSet: true}},
+        "sensor pod missing":         {k8sClientConfig: fakeK8sClientConfig{skipSensorPod: true}},
+    }
+    for tcName, tc := range testCases {
+        t.Run(tcName, func(t *testing.T) {
+            fixture := newLocalScannerTLSIssuerFixture(tc.k8sClientConfig)
+            fixture.certRefresher.On("Stop").Once()
+            fixture.certRequester.On("Stop").Once()
+
+            startErr := fixture.tlsIssuer.Start()
+
+            assert.NoError(t, startErr)
+            fixture.assertMockExpectations(t)
+        })
+    }
+}
+
+func TestLocalScannerTLSIssuerProcessMessageKnownMessage(t *testing.T) {
+    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+    defer cancel()
+    processMessageDoneSignal := concurrency.NewErrorSignal()
+    fixture := newLocalScannerTLSIssuerFixture(fakeK8sClientConfig{})
+    expectedResponse := &central.IssueLocalScannerCertsResponse{
+        RequestId: uuid.NewDummy().String(),
+    }
+    msg := &central.MsgToSensor{
+        Msg: &central.MsgToSensor_IssueLocalScannerCertsResponse{
+            IssueLocalScannerCertsResponse: expectedResponse,
}, + } + + go func() { + assert.NoError(t, fixture.tlsIssuer.ProcessMessage(msg)) + processMessageDoneSignal.Signal() + }() + + select { + case <-ctx.Done(): + assert.Fail(t, ctx.Err().Error()) + case response := <-fixture.tlsIssuer.msgFromCentralC: + assert.Equal(t, expectedResponse, response) + } + + _, ok := processMessageDoneSignal.WaitWithTimeout(100 * time.Millisecond) + assert.True(t, ok) +} + +func TestLocalScannerTLSIssuerProcessMessageUnknownMessage(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + processMessageDoneSignal := concurrency.NewErrorSignal() + fixture := newLocalScannerTLSIssuerFixture(fakeK8sClientConfig{}) + msg := ¢ral.MsgToSensor{ + Msg: ¢ral.MsgToSensor_ReprocessDeployments{}, + } + + go func() { + assert.NoError(t, fixture.tlsIssuer.ProcessMessage(msg)) + processMessageDoneSignal.Signal() + }() + + select { + case <-ctx.Done(): + case <-fixture.tlsIssuer.msgFromCentralC: + assert.Fail(t, "unknown message is not ignored") + } + _, ok := processMessageDoneSignal.WaitWithTimeout(100 * time.Millisecond) + assert.True(t, ok) +} + +func TestLocalScannerTLSIssuerIntegrationTests(t *testing.T) { + suite.Run(t, new(localScannerTLSIssueIntegrationTests)) +} + +type localScannerTLSIssueIntegrationTests struct { + suite.Suite + envIsolator *envisolator.EnvIsolator +} + +func (s *localScannerTLSIssueIntegrationTests) SetupSuite() { + s.envIsolator = envisolator.NewEnvIsolator(s.T()) +} + +func (s *localScannerTLSIssueIntegrationTests) SetupTest() { + err := testutilsMTLS.LoadTestMTLSCerts(s.envIsolator) + s.Require().NoError(err) +} + +func (s *localScannerTLSIssueIntegrationTests) TearDownTest() { + s.envIsolator.RestoreAll() +} + +func (s *localScannerTLSIssueIntegrationTests) TestSuccessfulRefresh() { + testCases := map[string]struct { + k8sClientConfig fakeK8sClientConfig + numFailedResponses int + }{ + "no secrets": {k8sClientConfig: fakeK8sClientConfig{}}, + "corrupted data in scanner secret": { + k8sClientConfig: fakeK8sClientConfig{ + secretsData: map[string]map[string][]byte{"scanner-slim-tls": nil}, + }, + }, + "corrupted data in scanner DB secret": { + k8sClientConfig: fakeK8sClientConfig{ + secretsData: map[string]map[string][]byte{"scanner-db-slim-tls": nil}, + }, + }, + "corrupted data in all local scanner secrets": { + k8sClientConfig: fakeK8sClientConfig{ + secretsData: map[string]map[string][]byte{"scanner-slim-tls": nil, "scanner-db-slim-tls": nil}, + }, + }, + "refresh failure and retries": {k8sClientConfig: fakeK8sClientConfig{}, numFailedResponses: 2}, + } + for tcName, tc := range testCases { + s.Run(tcName, func() { + testTimeout := 100 * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + ca, err := mtls.CAForSigning() + s.Require().NoError(err) + scannerCert := s.getCertificate(storage.ServiceType_SCANNER_SERVICE) + scannerDBCert := s.getCertificate(storage.ServiceType_SCANNER_DB_SERVICE) + k8sClient := getFakeK8sClient(tc.k8sClientConfig) + tlsIssuer := newLocalScannerTLSIssuer(s.T(), k8sClient, sensorNamespace, sensorPodName) + tlsIssuer.certRefreshBackoff = wait.Backoff{ + Duration: time.Millisecond, + } + + s.Require().NoError(tlsIssuer.Start()) + defer tlsIssuer.Stop(nil) + s.Require().NotNil(tlsIssuer.certRefresher) + s.Require().False(tlsIssuer.certRefresher.Stopped()) + + for i := 0; i < tc.numFailedResponses; i++ { + request := s.waitForRequest(ctx, tlsIssuer) + response := 
getIssueCertsFailureResponse(request.GetRequestId())
+                err = tlsIssuer.ProcessMessage(response)
+                s.Require().NoError(err)
+            }
+
+            request := s.waitForRequest(ctx, tlsIssuer)
+            response := getIssueCertsSuccessResponse(request.GetRequestId(), ca.CertPEM(), scannerCert, scannerDBCert)
+            err = tlsIssuer.ProcessMessage(response)
+            s.Require().NoError(err)
+
+            var secrets *v1.SecretList
+            ok := concurrency.PollWithTimeout(func() bool {
+                secrets, err = k8sClient.CoreV1().Secrets(sensorNamespace).List(context.Background(), metav1.ListOptions{})
+                s.Require().NoError(err)
+                return len(secrets.Items) == 2 && len(secrets.Items[0].Data) > 0 && len(secrets.Items[1].Data) > 0
+            }, 10*time.Millisecond, testTimeout)
+            s.Require().True(ok, "expected exactly 2 secrets with non-empty data available in the k8s API")
+            for _, secret := range secrets.Items {
+                var expectedCert *mtls.IssuedCert
+                switch secretName := secret.GetName(); secretName {
+                case "scanner-slim-tls":
+                    expectedCert = scannerCert
+                case "scanner-db-slim-tls":
+                    expectedCert = scannerDBCert
+                default:
+                    s.Require().Failf("expected secret name should be either %q or %q, found %q instead",
+                        "scanner-slim-tls", "scanner-db-slim-tls", secretName)
+                }
+                s.Equal(ca.CertPEM(), secret.Data[mtls.CACertFileName])
+                s.Equal(expectedCert.CertPEM, secret.Data[mtls.ServiceCertFileName])
+                s.Equal(expectedCert.KeyPEM, secret.Data[mtls.ServiceKeyFileName])
+            }
+        })
+    }
+}
+
+func (s *localScannerTLSIssueIntegrationTests) TestUnexpectedOwnerStop() {
+    testCases := map[string]struct {
+        secretNames []string
+    }{
+        "wrong owner for scanner secret":                 {secretNames: []string{"scanner-slim-tls"}},
+        "wrong owner for scanner db secret":              {secretNames: []string{"scanner-db-slim-tls"}},
+        "wrong owner for scanner and scanner db secrets": {secretNames: []string{"scanner-slim-tls", "scanner-db-slim-tls"}},
+    }
+    for tcName, tc := range testCases {
+        s.Run(tcName, func() {
+            secretsData := make(map[string]map[string][]byte, len(tc.secretNames))
+            for _, secretName := range tc.secretNames {
+                secretsData[secretName] = nil
+            }
+            k8sClient := getFakeK8sClient(fakeK8sClientConfig{
+                secretsData: secretsData,
+                secretsOwner: &metav1.OwnerReference{
+                    APIVersion: "apps/v1",
+                    Kind:       "Deployment",
+                    Name:       "another-deployment",
+                    UID:        types.UID(uuid.NewDummy().String()),
+                },
+            })
+            tlsIssuer := newLocalScannerTLSIssuer(s.T(), k8sClient, sensorNamespace, sensorPodName)
+
+            s.Require().NoError(tlsIssuer.Start())
+            defer tlsIssuer.Stop(nil)
+
+            ok := concurrency.PollWithTimeout(func() bool {
+                return tlsIssuer.certRefresher != nil && tlsIssuer.certRefresher.Stopped()
+            }, 10*time.Millisecond, 100*time.Millisecond)
+            s.True(ok, "cert refresher should be stopped")
+        })
+    }
+}
+
+func (s *localScannerTLSIssueIntegrationTests) getCertificate(serviceType storage.ServiceType) *mtls.IssuedCert {
+    // TODO(ROX-9463): use short expiration for testing renewal when ROX-9010 implementing `WithCustomCertLifetime` is merged
+    cert, err := issueCertificate(serviceType, mtls.WithValidityExpiringInHours())
+    s.Require().NoError(err)
+    return cert
+}
+
+func (s *localScannerTLSIssueIntegrationTests) waitForRequest(ctx context.Context, tlsIssuer common.SensorComponent) *central.IssueLocalScannerCertsRequest {
+    var request *central.MsgFromSensor
+    select {
+    case request = <-tlsIssuer.ResponsesC():
+    case <-ctx.Done():
+        s.Require().Fail(ctx.Err().Error())
+    }
+    s.Require().NotNil(request.GetIssueLocalScannerCertsRequest())
+
+    return request.GetIssueLocalScannerCertsRequest()
+}
+
+func
getIssueCertsSuccessResponse(requestID string, caPem []byte, scannerCert, scannerDBCert *mtls.IssuedCert) *central.MsgToSensor {
+    return &central.MsgToSensor{
+        Msg: &central.MsgToSensor_IssueLocalScannerCertsResponse{
+            IssueLocalScannerCertsResponse: &central.IssueLocalScannerCertsResponse{
+                RequestId: requestID,
+                Response: &central.IssueLocalScannerCertsResponse_Certificates{
+                    Certificates: &storage.TypedServiceCertificateSet{
+                        CaPem: caPem,
+                        ServiceCerts: []*storage.TypedServiceCertificate{
+                            {
+                                ServiceType: storage.ServiceType_SCANNER_SERVICE,
+                                Cert: &storage.ServiceCertificate{
+                                    KeyPem:  scannerCert.KeyPEM,
+                                    CertPem: scannerCert.CertPEM,
+                                },
+                            },
+                            {
+                                ServiceType: storage.ServiceType_SCANNER_DB_SERVICE,
+                                Cert: &storage.ServiceCertificate{
+                                    KeyPem:  scannerDBCert.KeyPEM,
+                                    CertPem: scannerDBCert.CertPEM,
+                                },
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    }
+}
+
+func getIssueCertsFailureResponse(requestID string) *central.MsgToSensor {
+    return &central.MsgToSensor{
+        Msg: &central.MsgToSensor_IssueLocalScannerCertsResponse{
+            IssueLocalScannerCertsResponse: &central.IssueLocalScannerCertsResponse{
+                RequestId: requestID,
+                Response: &central.IssueLocalScannerCertsResponse_Error{
+                    Error: &central.LocalScannerCertsIssueError{
+                        Message: "forced error",
+                    },
+                },
+            },
+        },
+    }
+}
+
+func getFakeK8sClient(conf fakeK8sClientConfig) *fake.Clientset {
+    objects := make([]runtime.Object, 0)
+    if !conf.skipSensorReplicaSet {
+        sensorDeploymentGVK := sensorDeployment.GroupVersionKind()
+        sensorReplicaSet := &appsApiv1.ReplicaSet{
+            ObjectMeta: metav1.ObjectMeta{
+                Name:      sensorReplicasetName,
+                Namespace: sensorNamespace,
+                OwnerReferences: []metav1.OwnerReference{
+                    {
+                        APIVersion: sensorDeploymentGVK.GroupVersion().String(),
+                        Kind:       sensorDeploymentGVK.Kind,
+                        Name:       sensorDeployment.GetName(),
+                        UID:        sensorDeployment.GetUID(),
+                    },
+                },
+            },
+        }
+        objects = append(objects, sensorReplicaSet)
+
+        sensorReplicaSetGVK := sensorReplicaSet.GroupVersionKind()
+        sensorReplicaSetOwnerRef := metav1.OwnerReference{
+            APIVersion: sensorReplicaSetGVK.GroupVersion().String(),
+            Kind:       sensorReplicaSet.Kind,
+            Name:       sensorReplicaSet.GetName(),
+            UID:        sensorReplicaSet.GetUID(),
+        }
+
+        if !conf.skipSensorPod {
+            sensorPod := &v1.Pod{
+                ObjectMeta: metav1.ObjectMeta{
+                    Name:            sensorPodName,
+                    Namespace:       sensorNamespace,
+                    OwnerReferences: []metav1.OwnerReference{sensorReplicaSetOwnerRef},
+                },
+            }
+            objects = append(objects, sensorPod)
+        }
+
+        secretsOwnerRef := sensorReplicaSetOwnerRef
+        if conf.secretsOwner != nil {
+            secretsOwnerRef = *conf.secretsOwner
+        }
+        for secretName, secretData := range conf.secretsData {
+            secret := &v1.Secret{
+                ObjectMeta: metav1.ObjectMeta{
+                    Name:            secretName,
+                    Namespace:       sensorNamespace,
+                    OwnerReferences: []metav1.OwnerReference{secretsOwnerRef},
+                },
+                Data: secretData,
+            }
+            objects = append(objects, secret)
+        }
+    }
+
+    k8sClient := fake.NewSimpleClientset(objects...)
+
+    return k8sClient
+}
+
+type fakeK8sClientConfig struct {
+    // if true then no sensor replica set and no sensor pod will be added to the test client.
+    skipSensorReplicaSet bool
+    // if true then no sensor pod will be added to the test client.
+    skipSensorPod bool
+    // if skipSensorReplicaSet is false, then a secret will be added to the test client for
+    // each entry in this map, using the key as the secret name and the value as the secret data.
+    secretsData map[string]map[string][]byte
+    // owner reference to use for the secrets specified in `secretsData`. If `nil` then the sensor
+    // replica set is used as owner.
+    secretsOwner *metav1.OwnerReference
+}
+
+func newLocalScannerTLSIssuer(
+    t *testing.T,
+    k8sClient kubernetes.Interface,
+    sensorNamespace string,
+    sensorPodName string,
+) *localScannerTLSIssuerImpl {
+    tlsIssuer := NewLocalScannerTLSIssuer(k8sClient, sensorNamespace, sensorPodName)
+    require.IsType(t, &localScannerTLSIssuerImpl{}, tlsIssuer)
+    return tlsIssuer.(*localScannerTLSIssuerImpl)
+}
+
+type certificateRequesterMock struct {
+    mock.Mock
+}
+
+func (m *certificateRequesterMock) Start() {
+    m.Called()
+}
+
+func (m *certificateRequesterMock) Stop() {
+    m.Called()
+}
+
+func (m *certificateRequesterMock) RequestCertificates(ctx context.Context) (*central.IssueLocalScannerCertsResponse, error) {
+    args := m.Called(ctx)
+    return args.Get(0).(*central.IssueLocalScannerCertsResponse), args.Error(1)
+}
+
+type certificateRefresherMock struct {
+    mock.Mock
+    stopped bool
+}
+
+func (m *certificateRefresherMock) Start() error {
+    args := m.Called()
+    return args.Error(0)
+}
+
+func (m *certificateRefresherMock) Stop() {
+    m.Called()
+    m.stopped = true
+}
+
+func (m *certificateRefresherMock) Stopped() bool {
+    return m.stopped
+}
+
+type componentGetterMock struct {
+    mock.Mock
+}
+
+func (m *componentGetterMock) getCertificateRefresher(requestCertificates requestCertificatesFunc,
+    repository serviceCertificatesRepo, timeout time.Duration, backoff wait.Backoff) concurrency.RetryTicker {
+    args := m.Called(requestCertificates, repository, timeout, backoff)
+    return args.Get(0).(concurrency.RetryTicker)
+}
+
+func (m *componentGetterMock) getServiceCertificatesRepo(ownerReference metav1.OwnerReference, namespace string,
+    secretsClient corev1.SecretInterface) serviceCertificatesRepo {
+    args := m.Called(ownerReference, namespace, secretsClient)
+    return args.Get(0).(serviceCertificatesRepo)
+}
+
+type certsRepoMock struct {
+    mock.Mock
+}
+
+func (m *certsRepoMock) getServiceCertificates(ctx context.Context) (*storage.TypedServiceCertificateSet, error) {
+    args := m.Called(ctx)
+    return args.Get(0).(*storage.TypedServiceCertificateSet), args.Error(1)
+}
+
+func (m *certsRepoMock) ensureServiceCertificates(ctx context.Context, certificates *storage.TypedServiceCertificateSet) error {
+    args := m.Called(ctx, certificates)
+    return args.Error(0)
+}
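getFakeK8sClient builds on client-go's fake clientset, which serves the seeded objects through the normal typed API; that is what lets the integration tests observe the Secret writes the refresher performs. A minimal sketch of the same technique (names mirror the test fixtures):

    package main

    import (
        "context"
        "fmt"

        v1 "k8s.io/api/core/v1"
        metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
        "k8s.io/client-go/kubernetes/fake"
    )

    func main() {
        // Seed the fake clientset with one secret, much as getFakeK8sClient
        // seeds the sensor pod, replica set, and scanner secrets.
        client := fake.NewSimpleClientset(&v1.Secret{
            ObjectMeta: metav1.ObjectMeta{Name: "scanner-slim-tls", Namespace: "stackrox-ns"},
        })

        // The fake client answers through the regular typed interface.
        secrets, err := client.CoreV1().Secrets("stackrox-ns").List(context.Background(), metav1.ListOptions{})
        fmt.Println(len(secrets.Items), err) // 1 <nil>
    }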
diff --git a/sensor/kubernetes/sensor/sensor.go b/sensor/kubernetes/sensor/sensor.go
index c54609364aee3..d15b7ca88b6f3 100644
--- a/sensor/kubernetes/sensor/sensor.go
+++ b/sensor/kubernetes/sensor/sensor.go
@@ -45,6 +45,7 @@ import (
     "github.com/stackrox/rox/sensor/kubernetes/fake"
     "github.com/stackrox/rox/sensor/kubernetes/listener"
     "github.com/stackrox/rox/sensor/kubernetes/listener/resources"
+    "github.com/stackrox/rox/sensor/kubernetes/localscanner"
     "github.com/stackrox/rox/sensor/kubernetes/networkpolicies"
     "github.com/stackrox/rox/sensor/kubernetes/orchestrator"
     "github.com/stackrox/rox/sensor/kubernetes/telemetry"
@@ -162,6 +163,13 @@ func CreateSensor(client client.Interface, workloadHandler *fake.WorkloadManager
         return nil, errors.Wrap(err, "creating central client")
     }
 
+    if features.LocalImageScanning.Enabled() && (helmManagedConfig.GetManagedBy() != storage.ManagerType_MANAGER_TYPE_UNKNOWN &&
+        helmManagedConfig.GetManagedBy() != storage.ManagerType_MANAGER_TYPE_MANUAL) {
+        podName := os.Getenv("POD_NAME")
+        components = append(components,
+            localscanner.NewLocalScannerTLSIssuer(client.Kubernetes(), sensorNamespace, podName))
+    }
+
     s := sensor.NewSensor(
         configHandler,
         policyDetector,
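CreateSensor registers the TLS issuer only when local image scanning is enabled and the cluster is Helm- or Operator-managed, and it reads the pod name from the POD_NAME variable that the sensor.yaml.htpl hunk at the top of this diff injects via the downward API. A sketch of that consumer side, with the log messages being illustrative:

    package main

    import (
        "fmt"
        "os"
    )

    func main() {
        // POD_NAME comes from the fieldRef to metadata.name added in the Helm
        // template; fetchSensorDeploymentOwnerRef uses it to walk
        // pod -> replica set -> deployment and build the secrets' owner reference.
        podName := os.Getenv("POD_NAME")
        if podName == "" {
            // Start reports "empty pod name" and abortStart swallows the error,
            // so sensor keeps running without local scanner certificate refresh.
            fmt.Println("POD_NAME not set; local scanner TLS issuer will not start")
            return
        }
        fmt.Println("sensor pod:", podName)
    }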