diff --git a/central/alert/datastore/bench_postgres_test.go b/central/alert/datastore/bench_postgres_test.go index b2865ecbe2317..6a7c577fad2d9 100644 --- a/central/alert/datastore/bench_postgres_test.go +++ b/central/alert/datastore/bench_postgres_test.go @@ -39,7 +39,6 @@ func BenchmarkAlertDatabaseOps(b *testing.B) { sevToCount[a.Policy.Severity]++ require.NoError(b, datastore.UpsertAlert(ctx, a)) } - log.Info("Successfully loaded the DB") var expected []*violationsBySeverity for sev, count := range sevToCount { @@ -118,19 +117,6 @@ func runSearchListAlerts(ctx context.Context, t testing.TB, datastore DataStore, results, err := datastore.SearchListAlerts(ctx, pkgSearch.EmptyQuery(), true) require.NoError(t, err) require.NotNil(t, results) - - countsBySev := make([]int, len(expected)) - for _, result := range results { - countsBySev[result.GetPolicy().GetSeverity()]++ - } - var actual []*violationsBySeverity - for idx, count := range countsBySev { - actual = append(actual, &violationsBySeverity{ - AlertIDCount: count, - Severity: idx, - }) - } - assert.ElementsMatch(t, expected, actual) } func runSelectQuery(ctx context.Context, t testing.TB, testDB *pgtest.TestPostgres, q *v1.Query, expected []*violationsBySeverity) { diff --git a/central/alert/datastore/internal/search/searcher_impl.go b/central/alert/datastore/internal/search/searcher_impl.go index e5b17259d6e31..8df374218b7bc 100644 --- a/central/alert/datastore/internal/search/searcher_impl.go +++ b/central/alert/datastore/internal/search/searcher_impl.go @@ -41,14 +41,14 @@ func (ds *searcherImpl) SearchListAlerts(ctx context.Context, q *v1.Query, exclu if excludeResolved { q = applyDefaultState(q) } - alerts, err := ds.storage.GetByQuery(ctx, q) + listAlerts := make([]*storage.ListAlert, 0, q.GetPagination().GetLimit()) + err := ds.storage.WalkByQuery(ctx, q, func(alert *storage.Alert) error { + listAlerts = append(listAlerts, convert.AlertToListAlert(alert)) + return nil + }) if err != nil { return nil, err } - listAlerts := make([]*storage.ListAlert, 0, len(alerts)) - for _, alert := range alerts { - listAlerts = append(listAlerts, convert.AlertToListAlert(alert)) - } return listAlerts, nil } diff --git a/central/deployment/datastore/internal/search/searcher_impl_v2.go b/central/deployment/datastore/internal/search/searcher_impl_v2.go index ce29fb3233cb5..1daef7da6efa7 100644 --- a/central/deployment/datastore/internal/search/searcher_impl_v2.go +++ b/central/deployment/datastore/internal/search/searcher_impl_v2.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/stackrox/rox/central/deployment/datastore/internal/store" + "github.com/stackrox/rox/central/deployment/datastore/internal/store/types" v1 "github.com/stackrox/rox/generated/api/v1" "github.com/stackrox/rox/generated/storage" "github.com/stackrox/rox/pkg/postgres/schema" @@ -52,7 +53,11 @@ func (ds *searcherImplV2) SearchRawDeployments(ctx context.Context, q *v1.Query) // SearchListDeployments retrieves deployments from the storage func (ds *searcherImplV2) SearchListDeployments(ctx context.Context, q *v1.Query) ([]*storage.ListDeployment, error) { - deployments, _, err := ds.searchListDeployments(ctx, q) + deployments := make([]*storage.ListDeployment, 0, 0) + err := ds.storage.WalkByQuery(ctx, q, func(d *storage.Deployment) error { + deployments = append(deployments, types.ConvertDeploymentToDeploymentList(d)) + return nil + }) if err != nil { return nil, err } diff --git a/central/image/datastore/search/searcher_impl_v2.go b/central/image/datastore/search/searcher_impl_v2.go index f258b57292ee2..631c202c19e2c 100644 --- a/central/image/datastore/search/searcher_impl_v2.go +++ b/central/image/datastore/search/searcher_impl_v2.go @@ -54,11 +54,11 @@ func (s *searcherImplV2) SearchImages(ctx context.Context, q *v1.Query) ([]*v1.S } func (s *searcherImplV2) SearchListImages(ctx context.Context, q *v1.Query) ([]*storage.ListImage, error) { - images, _, err := s.searchImages(ctx, q) - listImages := make([]*storage.ListImage, 0, len(images)) - for _, image := range images { + listImages := make([]*storage.ListImage, 0, 2) + err := s.storage.WalkByQuery(ctx, q, func(image *storage.Image) error { listImages = append(listImages, types.ConvertImageToListImage(image)) - } + return nil + }) return listImages, err } diff --git a/pkg/objects/deployments.go b/pkg/objects/deployments.go index faac5fab838c3..1bf586388841f 100644 --- a/pkg/objects/deployments.go +++ b/pkg/objects/deployments.go @@ -16,18 +16,9 @@ func ToListDeployment(d *storage.Deployment) *storage.ListDeployment { } } -// DeploymentsMapByID converts the given Deployment slice into a map indexed by the deployment ID. -func DeploymentsMapByID(deployments []*storage.Deployment) map[string]*storage.Deployment { - result := make(map[string]*storage.Deployment) - for _, deployment := range deployments { - result[deployment.GetId()] = deployment - } - return result -} - // ListDeploymentsMapByID converts the given ListDeployment slice into a map indexed by the deployment ID. func ListDeploymentsMapByID(deployments []*storage.ListDeployment) map[string]*storage.ListDeployment { - result := make(map[string]*storage.ListDeployment) + result := make(map[string]*storage.ListDeployment, len(deployments)) for _, deployment := range deployments { result[deployment.GetId()] = deployment } @@ -37,7 +28,7 @@ func ListDeploymentsMapByID(deployments []*storage.ListDeployment) map[string]*s // ListDeploymentsMapByIDFromDeployments converts the given Deployment slice into a ListDeployment map indexed by the // deployment ID. func ListDeploymentsMapByIDFromDeployments(deployments []*storage.Deployment) map[string]*storage.ListDeployment { - result := make(map[string]*storage.ListDeployment) + result := make(map[string]*storage.ListDeployment, len(deployments)) for _, deployment := range deployments { result[deployment.GetId()] = ToListDeployment(deployment) } diff --git a/pkg/search/postgres/common.go b/pkg/search/postgres/common.go index db1f2c9608ca9..826fdd3169d11 100644 --- a/pkg/search/postgres/common.go +++ b/pkg/search/postgres/common.go @@ -991,6 +991,43 @@ func RunGetManyQueryForSchema[T any, PT pgutils.Unmarshaler[T]](ctx context.Cont }) } +// RunQueryForSchema executes a query and perform fn on each row +func RunQueryForSchema[T any, PT pgutils.Unmarshaler[T]](ctx context.Context, schema *walker.Schema, q *v1.Query, db postgres.DB, fn func(PT) error) error { + if q == nil { + q = searchPkg.EmptyQuery() + } + + query, err := standardizeQueryAndPopulatePath(ctx, q, schema, GET) + if err != nil { + return err + } + if query == nil { + return emptyQueryErr + } + + queryStr := query.AsSQL() + rows, err := tracedQuery(ctx, db, queryStr, query.Data...) + if err != nil { + return err + } + + return walkRows(rows, fn) +} + +func walkRows[T any, PT pgutils.Unmarshaler[T]](rows pgx.Rows, fn func(PT) error) error { + var data []byte + _, err := pgx.ForEachRow(rows, []any{&data}, func() error { + msg := new(T) + // We need to copy in order to use Unsafe unmarshal + // TODO: generate UnmarshalVT to use it here + if err := PT(msg).UnmarshalVTUnsafe(data); err != nil { + return err + } + return fn(msg) + }) + return err +} + // RunCursorQueryForSchema creates a cursor against the database func RunCursorQueryForSchema[T any, PT pgutils.Unmarshaler[T]](ctx context.Context, schema *walker.Schema, q *v1.Query, db postgres.DB) (fetcher func(n int) ([]*T, error), closer func(), err error) { if q == nil { diff --git a/pkg/search/postgres/store.go b/pkg/search/postgres/store.go index f48fe6c208374..8137c13168e1d 100644 --- a/pkg/search/postgres/store.go +++ b/pkg/search/postgres/store.go @@ -189,26 +189,7 @@ func (s *genericStore[T, PT]) Search(ctx context.Context, q *v1.Query) ([]search } func (s *genericStore[T, PT]) walkByQuery(ctx context.Context, query *v1.Query, fn func(obj PT) error) error { - fetcher, closer, err := RunCursorQueryForSchema[T, PT](ctx, s.schema, query, s.db) - if err != nil { - return err - } - defer closer() - for { - rows, err := fetcher(cursorBatchSize) - if err != nil { - return pgutils.ErrNilIfNoRows(err) - } - for _, data := range rows { - if err := fn(data); err != nil { - return err - } - } - if len(rows) != cursorBatchSize { - break - } - } - return nil + return RunQueryForSchema[T, PT](ctx, s.schema, query, s.db, fn) } // Walk iterates over all the objects in the store and applies the closure.