From 2078c690f5bc091a9c1f6f9a2b3a185d9e5e8282 Mon Sep 17 00:00:00 2001 From: Tomasz Janiszewski Date: Tue, 23 Sep 2025 16:27:30 +0200 Subject: [PATCH] fix(risk): improve reprocess deployment risk Signed-off-by: Tomasz Janiszewski --- .../evaluator/evaluator_impl.go | 24 ++++++++++--------- .../evaluator/evaluator_test.go | 10 +++++++- .../processindicator/datastore/datastore.go | 1 + .../datastore/datastore_impl.go | 4 ++++ .../datastore/mocks/datastore.go | 14 +++++++++++ central/processindicator/store/mocks/store.go | 14 +++++++++++ central/processindicator/store/store.go | 1 + 7 files changed, 56 insertions(+), 12 deletions(-) diff --git a/central/processbaseline/evaluator/evaluator_impl.go b/central/processbaseline/evaluator/evaluator_impl.go index a0f65188c2dcf..8543c9aecd595 100644 --- a/central/processbaseline/evaluator/evaluator_impl.go +++ b/central/processbaseline/evaluator/evaluator_impl.go @@ -86,30 +86,32 @@ func (e *evaluator) EvaluateBaselinesAndPersistResult(deployment *storage.Deploy } } - var processes []*storage.ProcessIndicator - if hasAtLeastOneLockedBaseline { - processes, err = e.indicators.SearchRawProcessIndicators(evaluatorCtx, search.NewQueryBuilder().AddExactMatches(search.DeploymentID, deployment.GetId()).ProtoQuery()) - if err != nil { - return nil, errors.Wrapf(err, "searching process indicators for deployment %s/%s/%s", deployment.GetClusterName(), deployment.GetNamespace(), deployment.GetName()) - } - } - for _, process := range processes { + fn := func(process *storage.ProcessIndicator) error { processSet, exists := containerNameToBaselinedProcesses[process.GetContainerName()] // If no explicit baseline, then all processes are valid. if !exists { - continue + return nil } baselineItem := processBaselinePkg.BaselineItemFromProcess(process) if baselineItem == "" { - continue + return nil } if processbaseline.IsStartupProcess(process) { - continue + return nil } if !processSet.Contains(processBaselinePkg.BaselineItemFromProcess(process)) { violatingProcesses = append(violatingProcesses, process) containerNameToBaselineResults[process.GetContainerName()].AnomalousProcessesExecuted = true } + return nil + } + + if hasAtLeastOneLockedBaseline { + query := search.NewQueryBuilder().AddExactMatches(search.DeploymentID, deployment.GetId()).ProtoQuery() + err := e.indicators.GetByQueryFn(evaluatorCtx, query, fn) + if err != nil { + return nil, errors.Wrapf(err, "searching process indicators for deployment %s/%s/%s", deployment.GetClusterName(), deployment.GetNamespace(), deployment.GetName()) + } } baselineResults, err := e.baselineResults.GetBaselineResults(evaluatorCtx, deployment.GetId()) diff --git a/central/processbaseline/evaluator/evaluator_test.go b/central/processbaseline/evaluator/evaluator_test.go index eec29c8f9e9ba..1331777ec9840 100644 --- a/central/processbaseline/evaluator/evaluator_test.go +++ b/central/processbaseline/evaluator/evaluator_test.go @@ -1,6 +1,7 @@ package evaluator import ( + "context" "strings" "testing" "time" @@ -8,6 +9,7 @@ import ( processBaselineMocks "github.com/stackrox/rox/central/processbaseline/datastore/mocks" processBaselineResultMocks "github.com/stackrox/rox/central/processbaselineresults/datastore/mocks" processIndicatorMocks "github.com/stackrox/rox/central/processindicator/datastore/mocks" + v1 "github.com/stackrox/rox/generated/api/v1" "github.com/stackrox/rox/generated/storage" "github.com/stackrox/rox/pkg/fixtures" "github.com/stackrox/rox/pkg/protoassert" @@ -269,7 +271,13 @@ func TestProcessBaselineEvaluator(t *testing.T) { mockBaselines.EXPECT().GetProcessBaseline(gomock.Any(), gomock.Any()).MaxTimes(len(deployment.GetContainers())).Return(c.baseline, c.baseline != nil, c.baselineErr) if c.indicators != nil { - mockIndicators.EXPECT().SearchRawProcessIndicators(gomock.Any(), gomock.Any()).Return(c.indicators, c.indicatorErr) + mockIndicators.EXPECT().GetByQueryFn(gomock.Any(), gomock.Any(), gomock.Any()).Do( + func(_ context.Context, _ *v1.Query, fn func(indicator *storage.ProcessIndicator) error) { + for _, i := range c.indicators { + err := fn(i) + require.NoError(t, err) + } + }).Return(c.indicatorErr) } expectedBaselineResult := &storage.ProcessBaselineResults{ diff --git a/central/processindicator/datastore/datastore.go b/central/processindicator/datastore/datastore.go index 31e24cf305257..f5ecc97c100c7 100644 --- a/central/processindicator/datastore/datastore.go +++ b/central/processindicator/datastore/datastore.go @@ -26,6 +26,7 @@ type DataStore interface { Search(ctx context.Context, q *v1.Query) ([]pkgSearch.Result, error) SearchRawProcessIndicators(ctx context.Context, q *v1.Query) ([]*storage.ProcessIndicator, error) + GetByQueryFn(ctx context.Context, query *v1.Query, fn func(obj *storage.ProcessIndicator) error) error GetProcessIndicator(ctx context.Context, id string) (*storage.ProcessIndicator, bool, error) GetProcessIndicators(ctx context.Context, ids []string) ([]*storage.ProcessIndicator, bool, error) diff --git a/central/processindicator/datastore/datastore_impl.go b/central/processindicator/datastore/datastore_impl.go index e8e11fbf39d1b..7b6f2ad943116 100644 --- a/central/processindicator/datastore/datastore_impl.go +++ b/central/processindicator/datastore/datastore_impl.go @@ -58,6 +58,10 @@ func (ds *datastoreImpl) SearchRawProcessIndicators(ctx context.Context, q *v1.Q return ds.storage.GetByQuery(ctx, q) } +func (ds *datastoreImpl) GetByQueryFn(ctx context.Context, query *v1.Query, fn func(obj *storage.ProcessIndicator) error) error { + return ds.storage.GetByQueryFn(ctx, query, fn) +} + func (ds *datastoreImpl) GetProcessIndicator(ctx context.Context, id string) (*storage.ProcessIndicator, bool, error) { indicator, exists, err := ds.storage.Get(ctx, id) if err != nil || !exists { diff --git a/central/processindicator/datastore/mocks/datastore.go b/central/processindicator/datastore/mocks/datastore.go index b0526a7557122..8ddba3fcdf2ff 100644 --- a/central/processindicator/datastore/mocks/datastore.go +++ b/central/processindicator/datastore/mocks/datastore.go @@ -78,6 +78,20 @@ func (mr *MockDataStoreMockRecorder) Count(ctx, q any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockDataStore)(nil).Count), ctx, q) } +// GetByQueryFn mocks base method. +func (m *MockDataStore) GetByQueryFn(ctx context.Context, query *v1.Query, fn func(*storage.ProcessIndicator) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetByQueryFn", ctx, query, fn) + ret0, _ := ret[0].(error) + return ret0 +} + +// GetByQueryFn indicates an expected call of GetByQueryFn. +func (mr *MockDataStoreMockRecorder) GetByQueryFn(ctx, query, fn any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByQueryFn", reflect.TypeOf((*MockDataStore)(nil).GetByQueryFn), ctx, query, fn) +} + // GetProcessIndicator mocks base method. func (m *MockDataStore) GetProcessIndicator(ctx context.Context, id string) (*storage.ProcessIndicator, bool, error) { m.ctrl.T.Helper() diff --git a/central/processindicator/store/mocks/store.go b/central/processindicator/store/mocks/store.go index 4a93fcdd2e8d4..998b4e456bc70 100644 --- a/central/processindicator/store/mocks/store.go +++ b/central/processindicator/store/mocks/store.go @@ -117,6 +117,20 @@ func (mr *MockStoreMockRecorder) GetByQuery(ctx, q any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByQuery", reflect.TypeOf((*MockStore)(nil).GetByQuery), ctx, q) } +// GetByQueryFn mocks base method. +func (m *MockStore) GetByQueryFn(ctx context.Context, query *v1.Query, fn func(*storage.ProcessIndicator) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetByQueryFn", ctx, query, fn) + ret0, _ := ret[0].(error) + return ret0 +} + +// GetByQueryFn indicates an expected call of GetByQueryFn. +func (mr *MockStoreMockRecorder) GetByQueryFn(ctx, query, fn any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByQueryFn", reflect.TypeOf((*MockStore)(nil).GetByQueryFn), ctx, query, fn) +} + // GetMany mocks base method. func (m *MockStore) GetMany(ctx context.Context, ids []string) ([]*storage.ProcessIndicator, []int, error) { m.ctrl.T.Helper() diff --git a/central/processindicator/store/store.go b/central/processindicator/store/store.go index e84ad6597e44f..44e36f920841a 100644 --- a/central/processindicator/store/store.go +++ b/central/processindicator/store/store.go @@ -17,6 +17,7 @@ type Store interface { Get(ctx context.Context, id string) (*storage.ProcessIndicator, bool, error) GetByQuery(ctx context.Context, q *v1.Query) ([]*storage.ProcessIndicator, error) + GetByQueryFn(ctx context.Context, query *v1.Query, fn func(obj *storage.ProcessIndicator) error) error GetMany(ctx context.Context, ids []string) ([]*storage.ProcessIndicator, []int, error) UpsertMany(context.Context, []*storage.ProcessIndicator) error