From 7edd1aa0b7b6b120be1905718e5edb01da314f7d Mon Sep 17 00:00:00 2001 From: Connor Gorman Date: Mon, 13 Mar 2023 21:41:02 -0700 Subject: [PATCH 01/40] Add Blob store to Postgres --- central/blob/datastore/datastore.go | 36 ++ central/blob/datastore/singleton.go | 21 + central/blob/datastore/store/postgres/gen.go | 3 + .../blob/datastore/store/postgres/store.go | 536 +++++++++++++++++ .../datastore/store/postgres/store_test.go | 118 ++++ central/blob/datastore/store/store.go | 148 +++++ central/blob/datastore/store/store_test.go | 74 +++ generated/storage/blob.pb.go | 557 ++++++++++++++++++ pkg/postgres/conn.go | 39 +- pkg/postgres/context.go | 19 + pkg/postgres/pool.go | 31 +- pkg/postgres/schema/blobs.go | 40 ++ .../pg-table-bindings/list.go | 1 + 13 files changed, 1617 insertions(+), 6 deletions(-) create mode 100644 central/blob/datastore/datastore.go create mode 100644 central/blob/datastore/singleton.go create mode 100644 central/blob/datastore/store/postgres/gen.go create mode 100644 central/blob/datastore/store/postgres/store.go create mode 100644 central/blob/datastore/store/postgres/store_test.go create mode 100644 central/blob/datastore/store/store.go create mode 100644 central/blob/datastore/store/store_test.go create mode 100644 generated/storage/blob.pb.go create mode 100644 pkg/postgres/context.go create mode 100644 pkg/postgres/schema/blobs.go diff --git a/central/blob/datastore/datastore.go b/central/blob/datastore/datastore.go new file mode 100644 index 0000000000000..6c8f8360d5d1a --- /dev/null +++ b/central/blob/datastore/datastore.go @@ -0,0 +1,36 @@ +package datastore + +import ( + "context" + "io" + + "github.com/stackrox/rox/central/blob/datastore/store" + "github.com/stackrox/rox/generated/storage" +) + +// Datastore provides access to the blob store +type Datastore interface { + Upsert(ctx context.Context, obj *storage.Blob, reader io.Reader) error + Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) +} + 
+// NewDatastore creates a new Blob datastore +func NewDatastore(store store.Store) Datastore { + return &datastoreImpl{ + store: store, + } +} + +type datastoreImpl struct { + store store.Store +} + +// Upsert adds a new blob to the database +func (d *datastoreImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io.Reader) error { + return d.store.Upsert(ctx, obj, reader) +} + +// Get retrieves a blob from the database +func (d *datastoreImpl) Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) { + return d.store.Get(ctx, name, writer) +} diff --git a/central/blob/datastore/singleton.go b/central/blob/datastore/singleton.go new file mode 100644 index 0000000000000..171040d139951 --- /dev/null +++ b/central/blob/datastore/singleton.go @@ -0,0 +1,21 @@ +package datastore + +import ( + "github.com/stackrox/rox/central/blob/datastore/store" + "github.com/stackrox/rox/central/globaldb" + "github.com/stackrox/rox/pkg/sync" +) + +var ( + once sync.Once + + ds Datastore +) + +// Singleton returns the blob datastore +func Singleton() Datastore { + once.Do(func() { + ds = NewDatastore(store.New(globaldb.GetPostgres())) + }) + return ds +} diff --git a/central/blob/datastore/store/postgres/gen.go b/central/blob/datastore/store/postgres/gen.go new file mode 100644 index 0000000000000..47d6b48ec8853 --- /dev/null +++ b/central/blob/datastore/store/postgres/gen.go @@ -0,0 +1,3 @@ +package postgres + +//go:generate pg-table-bindings-wrapper --type=storage.Blob diff --git a/central/blob/datastore/store/postgres/store.go b/central/blob/datastore/store/postgres/store.go new file mode 100644 index 0000000000000..421213e2dd50d --- /dev/null +++ b/central/blob/datastore/store/postgres/store.go @@ -0,0 +1,536 @@ +// Code generated by pg-bindings generator. DO NOT EDIT. 
+ +package postgres + +import ( + "context" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/jackc/pgx/v4" + "github.com/pkg/errors" + "github.com/stackrox/rox/central/metrics" + "github.com/stackrox/rox/central/role/resources" + v1 "github.com/stackrox/rox/generated/api/v1" + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/logging" + ops "github.com/stackrox/rox/pkg/metrics" + "github.com/stackrox/rox/pkg/postgres" + "github.com/stackrox/rox/pkg/postgres/pgutils" + pkgSchema "github.com/stackrox/rox/pkg/postgres/schema" + "github.com/stackrox/rox/pkg/sac" + "github.com/stackrox/rox/pkg/search" + pgSearch "github.com/stackrox/rox/pkg/search/postgres" + "github.com/stackrox/rox/pkg/sync" + "gorm.io/gorm" +) + +const ( + baseTable = "blobs" + + batchAfter = 100 + + // using copyFrom, we may not even want to batch. It would probably be simpler + // to deal with failures if we just sent it all. Something to think about as we + // proceed and move into more e2e and larger performance testing + batchSize = 10000 + + cursorBatchSize = 50 + deleteBatchSize = 5000 +) + +var ( + log = logging.LoggerForModule() + schema = pkgSchema.BlobsSchema + targetResource = resources.Administration +) + +// Store is the interface to interact with the storage for storage.Blob +type Store interface { + Upsert(ctx context.Context, obj *storage.Blob) error + UpsertMany(ctx context.Context, objs []*storage.Blob) error + Delete(ctx context.Context, name string) error + DeleteByQuery(ctx context.Context, q *v1.Query) error + DeleteMany(ctx context.Context, identifiers []string) error + + Count(ctx context.Context) (int, error) + Exists(ctx context.Context, name string) (bool, error) + + Get(ctx context.Context, name string) (*storage.Blob, bool, error) + GetMany(ctx context.Context, identifiers []string) ([]*storage.Blob, []int, error) + GetIDs(ctx context.Context) ([]string, error) + + Walk(ctx context.Context, fn func(obj *storage.Blob) error) error + 
+ AckKeysIndexed(ctx context.Context, keys ...string) error + GetKeysToIndex(ctx context.Context) ([]string, error) +} + +type storeImpl struct { + db *postgres.DB + mutex sync.RWMutex +} + +// New returns a new Store instance using the provided sql instance. +func New(db *postgres.DB) Store { + return &storeImpl{ + db: db, + } +} + +//// Helper functions + +func insertIntoBlobs(ctx context.Context, batch *pgx.Batch, obj *storage.Blob) error { + + serialized, marshalErr := obj.Marshal() + if marshalErr != nil { + return marshalErr + } + + values := []interface{}{ + // parent primary keys start + obj.GetName(), + serialized, + } + + finalStr := "INSERT INTO blobs (Name, serialized) VALUES($1, $2) ON CONFLICT(Name) DO UPDATE SET Name = EXCLUDED.Name, serialized = EXCLUDED.serialized" + batch.Queue(finalStr, values...) + + return nil +} + +func (s *storeImpl) copyFromBlobs(ctx context.Context, tx *postgres.Tx, objs ...*storage.Blob) error { + + inputRows := [][]interface{}{} + + var err error + + // This is a copy so first we must delete the rows and re-add them + // Which is essentially the desired behaviour of an upsert. + var deletes []string + + copyCols := []string{ + + "name", + + "serialized", + } + + for idx, obj := range objs { + // Todo: ROX-9499 Figure out how to more cleanly template around this issue. + log.Debugf("This is here for now because there is an issue with pods_TerminatedInstances where the obj "+ + "in the loop is not used as it only consists of the parent ID and the index. Putting this here as a stop gap "+ + "to simply use the object. %s", obj) + + serialized, marshalErr := obj.Marshal() + if marshalErr != nil { + return marshalErr + } + + inputRows = append(inputRows, []interface{}{ + + obj.GetName(), + + serialized, + }) + + // Add the ID to be deleted. 
+ deletes = append(deletes, obj.GetName()) + + // if we hit our batch size we need to push the data + if (idx+1)%batchSize == 0 || idx == len(objs)-1 { + // copy does not upsert so have to delete first. parent deletion cascades so only need to + // delete for the top level parent + + if err := s.DeleteMany(ctx, deletes); err != nil { + return err + } + // clear the inserts and vals for the next batch + deletes = nil + + _, err = tx.CopyFrom(ctx, pgx.Identifier{"blobs"}, copyCols, pgx.CopyFromRows(inputRows)) + + if err != nil { + return err + } + + // clear the input rows for the next batch + inputRows = inputRows[:0] + } + } + + return err +} + +func (s *storeImpl) acquireConn(ctx context.Context, op ops.Op, typ string) (*postgres.Conn, func(), error) { + defer metrics.SetAcquireDBConnDuration(time.Now(), op, typ) + conn, err := s.db.Acquire(ctx) + if err != nil { + return nil, nil, err + } + return conn, conn.Release, nil +} + +func (s *storeImpl) copyFrom(ctx context.Context, objs ...*storage.Blob) error { + conn, release, err := s.acquireConn(ctx, ops.Get, "Blob") + if err != nil { + return err + } + defer release() + + tx, err := conn.Begin(ctx) + if err != nil { + return err + } + + if err := s.copyFromBlobs(ctx, tx, objs...); err != nil { + if err := tx.Rollback(ctx); err != nil { + return err + } + return err + } + if err := tx.Commit(ctx); err != nil { + return err + } + return nil +} + +func (s *storeImpl) upsert(ctx context.Context, objs ...*storage.Blob) error { + conn, release, err := s.acquireConn(ctx, ops.Get, "Blob") + if err != nil { + return err + } + defer release() + + for _, obj := range objs { + batch := &pgx.Batch{} + if err := insertIntoBlobs(ctx, batch, obj); err != nil { + return err + } + batchResults := conn.SendBatch(ctx, batch) + var result *multierror.Error + for i := 0; i < batch.Len(); i++ { + _, err := batchResults.Exec() + result = multierror.Append(result, err) + } + if err := batchResults.Close(); err != nil { + return err + } + 
if err := result.ErrorOrNil(); err != nil { + return err + } + } + return nil +} + +//// Helper functions - END + +//// Interface functions + +// Upsert saves the current state of an object in storage. +func (s *storeImpl) Upsert(ctx context.Context, obj *storage.Blob) error { + defer metrics.SetPostgresOperationDurationTime(time.Now(), ops.Upsert, "Blob") + + scopeChecker := sac.GlobalAccessScopeChecker(ctx).AccessMode(storage.Access_READ_WRITE_ACCESS).Resource(targetResource) + if !scopeChecker.IsAllowed() { + return sac.ErrResourceAccessDenied + } + + return pgutils.Retry(func() error { + return s.upsert(ctx, obj) + }) +} + +// UpsertMany saves the state of multiple objects in the storage. +func (s *storeImpl) UpsertMany(ctx context.Context, objs []*storage.Blob) error { + defer metrics.SetPostgresOperationDurationTime(time.Now(), ops.UpdateMany, "Blob") + + scopeChecker := sac.GlobalAccessScopeChecker(ctx).AccessMode(storage.Access_READ_WRITE_ACCESS).Resource(targetResource) + if !scopeChecker.IsAllowed() { + return sac.ErrResourceAccessDenied + } + + return pgutils.Retry(func() error { + // Lock since copyFrom requires a delete first before being executed. If multiple processes are updating + // same subset of rows, both deletes could occur before the copyFrom resulting in unique constraint + // violations + if len(objs) < batchAfter { + s.mutex.RLock() + defer s.mutex.RUnlock() + + return s.upsert(ctx, objs...) + } + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.copyFrom(ctx, objs...) + }) +} + +// Delete removes the object associated to the specified ID from the store. 
+func (s *storeImpl) Delete(ctx context.Context, name string) error { + defer metrics.SetPostgresOperationDurationTime(time.Now(), ops.Remove, "Blob") + + var sacQueryFilter *v1.Query + scopeChecker := sac.GlobalAccessScopeChecker(ctx).AccessMode(storage.Access_READ_WRITE_ACCESS).Resource(targetResource) + if !scopeChecker.IsAllowed() { + return sac.ErrResourceAccessDenied + } + + q := search.ConjunctionQuery( + sacQueryFilter, + search.NewQueryBuilder().AddDocIDs(name).ProtoQuery(), + ) + + return pgSearch.RunDeleteRequestForSchema(ctx, schema, q, s.db) +} + +// DeleteByQuery removes the objects from the store based on the passed query. +func (s *storeImpl) DeleteByQuery(ctx context.Context, query *v1.Query) error { + defer metrics.SetPostgresOperationDurationTime(time.Now(), ops.Remove, "Blob") + + var sacQueryFilter *v1.Query + scopeChecker := sac.GlobalAccessScopeChecker(ctx).AccessMode(storage.Access_READ_WRITE_ACCESS).Resource(targetResource) + if !scopeChecker.IsAllowed() { + return sac.ErrResourceAccessDenied + } + + q := search.ConjunctionQuery( + sacQueryFilter, + query, + ) + + return pgSearch.RunDeleteRequestForSchema(ctx, schema, q, s.db) +} + +// DeleteMany removes the objects associated to the specified IDs from the store. 
+func (s *storeImpl) DeleteMany(ctx context.Context, identifiers []string) error { + defer metrics.SetPostgresOperationDurationTime(time.Now(), ops.RemoveMany, "Blob") + + var sacQueryFilter *v1.Query + + scopeChecker := sac.GlobalAccessScopeChecker(ctx).AccessMode(storage.Access_READ_WRITE_ACCESS).Resource(targetResource) + if !scopeChecker.IsAllowed() { + return sac.ErrResourceAccessDenied + } + + // Batch the deletes + localBatchSize := deleteBatchSize + numRecordsToDelete := len(identifiers) + for { + if len(identifiers) == 0 { + break + } + + if len(identifiers) < localBatchSize { + localBatchSize = len(identifiers) + } + + identifierBatch := identifiers[:localBatchSize] + q := search.ConjunctionQuery( + sacQueryFilter, + search.NewQueryBuilder().AddDocIDs(identifierBatch...).ProtoQuery(), + ) + + if err := pgSearch.RunDeleteRequestForSchema(ctx, schema, q, s.db); err != nil { + return errors.Wrapf(err, "unable to delete the records. Successfully deleted %d out of %d", numRecordsToDelete-len(identifiers), numRecordsToDelete) + } + + // Move the slice forward to start the next batch + identifiers = identifiers[localBatchSize:] + } + + return nil +} + +// Count returns the number of objects in the store. +func (s *storeImpl) Count(ctx context.Context) (int, error) { + defer metrics.SetPostgresOperationDurationTime(time.Now(), ops.Count, "Blob") + + var sacQueryFilter *v1.Query + + scopeChecker := sac.GlobalAccessScopeChecker(ctx).AccessMode(storage.Access_READ_ACCESS).Resource(targetResource) + if !scopeChecker.IsAllowed() { + return 0, nil + } + + return pgSearch.RunCountRequestForSchema(ctx, schema, sacQueryFilter, s.db) +} + +// Exists returns if the ID exists in the store. 
+func (s *storeImpl) Exists(ctx context.Context, name string) (bool, error) { + defer metrics.SetPostgresOperationDurationTime(time.Now(), ops.Exists, "Blob") + + var sacQueryFilter *v1.Query + scopeChecker := sac.GlobalAccessScopeChecker(ctx).AccessMode(storage.Access_READ_ACCESS).Resource(targetResource) + if !scopeChecker.IsAllowed() { + return false, nil + } + + q := search.ConjunctionQuery( + sacQueryFilter, + search.NewQueryBuilder().AddDocIDs(name).ProtoQuery(), + ) + + count, err := pgSearch.RunCountRequestForSchema(ctx, schema, q, s.db) + // With joins and multiple paths to the scoping resources, it can happen that the Count query for an object identifier + // returns more than 1, despite the fact that the identifier is unique in the table. + return count > 0, err +} + +// Get returns the object, if it exists from the store. +func (s *storeImpl) Get(ctx context.Context, name string) (*storage.Blob, bool, error) { + defer metrics.SetPostgresOperationDurationTime(time.Now(), ops.Get, "Blob") + + var sacQueryFilter *v1.Query + + scopeChecker := sac.GlobalAccessScopeChecker(ctx).AccessMode(storage.Access_READ_ACCESS).Resource(targetResource) + if !scopeChecker.IsAllowed() { + return nil, false, nil + } + + q := search.ConjunctionQuery( + sacQueryFilter, + search.NewQueryBuilder().AddDocIDs(name).ProtoQuery(), + ) + + data, err := pgSearch.RunGetQueryForSchema[storage.Blob](ctx, schema, q, s.db) + if err != nil { + return nil, false, pgutils.ErrNilIfNoRows(err) + } + + return data, true, nil +} + +// GetMany returns the objects specified by the IDs from the store as well as the index in the missing indices slice. 
+func (s *storeImpl) GetMany(ctx context.Context, identifiers []string) ([]*storage.Blob, []int, error) { + defer metrics.SetPostgresOperationDurationTime(time.Now(), ops.GetMany, "Blob") + + if len(identifiers) == 0 { + return nil, nil, nil + } + + var sacQueryFilter *v1.Query + + scopeChecker := sac.GlobalAccessScopeChecker(ctx).AccessMode(storage.Access_READ_ACCESS).Resource(targetResource) + if !scopeChecker.IsAllowed() { + return nil, nil, nil + } + q := search.ConjunctionQuery( + sacQueryFilter, + search.NewQueryBuilder().AddDocIDs(identifiers...).ProtoQuery(), + ) + + rows, err := pgSearch.RunGetManyQueryForSchema[storage.Blob](ctx, schema, q, s.db) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + missingIndices := make([]int, 0, len(identifiers)) + for i := range identifiers { + missingIndices = append(missingIndices, i) + } + return nil, missingIndices, nil + } + return nil, nil, err + } + resultsByID := make(map[string]*storage.Blob, len(rows)) + for _, msg := range rows { + resultsByID[msg.GetName()] = msg + } + missingIndices := make([]int, 0, len(identifiers)-len(resultsByID)) + // It is important that the elems are populated in the same order as the input identifiers + // slice, since some calling code relies on that to maintain order. + elems := make([]*storage.Blob, 0, len(resultsByID)) + for i, identifier := range identifiers { + if result, ok := resultsByID[identifier]; !ok { + missingIndices = append(missingIndices, i) + } else { + elems = append(elems, result) + } + } + return elems, missingIndices, nil +} + +// GetIDs returns all the IDs for the store. 
+func (s *storeImpl) GetIDs(ctx context.Context) ([]string, error) { + defer metrics.SetPostgresOperationDurationTime(time.Now(), ops.GetAll, "storage.BlobIDs") + var sacQueryFilter *v1.Query + + scopeChecker := sac.GlobalAccessScopeChecker(ctx).AccessMode(storage.Access_READ_ACCESS).Resource(targetResource) + if !scopeChecker.IsAllowed() { + return nil, nil + } + result, err := pgSearch.RunSearchRequestForSchema(ctx, schema, sacQueryFilter, s.db) + if err != nil { + return nil, err + } + + identifiers := make([]string, 0, len(result)) + for _, entry := range result { + identifiers = append(identifiers, entry.ID) + } + + return identifiers, nil +} + +// Walk iterates over all of the objects in the store and applies the closure. +func (s *storeImpl) Walk(ctx context.Context, fn func(obj *storage.Blob) error) error { + var sacQueryFilter *v1.Query + scopeChecker := sac.GlobalAccessScopeChecker(ctx).AccessMode(storage.Access_READ_ACCESS).Resource(targetResource) + if !scopeChecker.IsAllowed() { + return nil + } + fetcher, closer, err := pgSearch.RunCursorQueryForSchema[storage.Blob](ctx, schema, sacQueryFilter, s.db) + if err != nil { + return err + } + defer closer() + for { + rows, err := fetcher(cursorBatchSize) + if err != nil { + return pgutils.ErrNilIfNoRows(err) + } + for _, data := range rows { + if err := fn(data); err != nil { + return err + } + } + if len(rows) != cursorBatchSize { + break + } + } + return nil +} + +//// Stubs for satisfying legacy interfaces + +// AckKeysIndexed acknowledges the passed keys were indexed. +func (s *storeImpl) AckKeysIndexed(ctx context.Context, keys ...string) error { + return nil +} + +// GetKeysToIndex returns the keys that need to be indexed. +func (s *storeImpl) GetKeysToIndex(ctx context.Context) ([]string, error) { + return nil, nil +} + +//// Interface functions - END + +//// Used for testing + +// CreateTableAndNewStore returns a new Store instance for testing. 
+func CreateTableAndNewStore(ctx context.Context, db *postgres.DB, gormDB *gorm.DB) Store { + pkgSchema.ApplySchemaForTable(ctx, gormDB, baseTable) + return New(db) +} + +// Destroy drops the tables associated with the target object type. +func Destroy(ctx context.Context, db *postgres.DB) { + dropTableBlobs(ctx, db) +} + +func dropTableBlobs(ctx context.Context, db *postgres.DB) { + _, _ = db.Exec(ctx, "DROP TABLE IF EXISTS blobs CASCADE") + +} + +//// Used for testing - END diff --git a/central/blob/datastore/store/postgres/store_test.go b/central/blob/datastore/store/postgres/store_test.go new file mode 100644 index 0000000000000..5ed5c7b17159e --- /dev/null +++ b/central/blob/datastore/store/postgres/store_test.go @@ -0,0 +1,118 @@ +// Code generated by pg-bindings generator. DO NOT EDIT. + +//go:build sql_integration + +package postgres + +import ( + "context" + "testing" + + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/env" + "github.com/stackrox/rox/pkg/postgres/pgtest" + "github.com/stackrox/rox/pkg/sac" + "github.com/stackrox/rox/pkg/testutils" + "github.com/stretchr/testify/suite" +) + +type BlobsStoreSuite struct { + suite.Suite + store Store + testDB *pgtest.TestPostgres +} + +func TestBlobsStore(t *testing.T) { + suite.Run(t, new(BlobsStoreSuite)) +} + +func (s *BlobsStoreSuite) SetupSuite() { + s.T().Setenv(env.PostgresDatastoreEnabled.EnvVar(), "true") + + if !env.PostgresDatastoreEnabled.BooleanSetting() { + s.T().Skip("Skip postgres store tests") + s.T().SkipNow() + } + + s.testDB = pgtest.ForT(s.T()) + s.store = New(s.testDB.DB) +} + +func (s *BlobsStoreSuite) SetupTest() { + ctx := sac.WithAllAccess(context.Background()) + tag, err := s.testDB.Exec(ctx, "TRUNCATE blobs CASCADE") + s.T().Log("blobs", tag) + s.NoError(err) +} + +func (s *BlobsStoreSuite) TearDownSuite() { + s.testDB.Teardown(s.T()) +} + +func (s *BlobsStoreSuite) TestStore() { + ctx := sac.WithAllAccess(context.Background()) + + store := s.store + + 
blob := &storage.Blob{} + s.NoError(testutils.FullInit(blob, testutils.SimpleInitializer(), testutils.JSONFieldsFilter)) + + foundBlob, exists, err := store.Get(ctx, blob.GetName()) + s.NoError(err) + s.False(exists) + s.Nil(foundBlob) + + withNoAccessCtx := sac.WithNoAccess(ctx) + + s.NoError(store.Upsert(ctx, blob)) + foundBlob, exists, err = store.Get(ctx, blob.GetName()) + s.NoError(err) + s.True(exists) + s.Equal(blob, foundBlob) + + blobCount, err := store.Count(ctx) + s.NoError(err) + s.Equal(1, blobCount) + blobCount, err = store.Count(withNoAccessCtx) + s.NoError(err) + s.Zero(blobCount) + + blobExists, err := store.Exists(ctx, blob.GetName()) + s.NoError(err) + s.True(blobExists) + s.NoError(store.Upsert(ctx, blob)) + s.ErrorIs(store.Upsert(withNoAccessCtx, blob), sac.ErrResourceAccessDenied) + + foundBlob, exists, err = store.Get(ctx, blob.GetName()) + s.NoError(err) + s.True(exists) + s.Equal(blob, foundBlob) + + s.NoError(store.Delete(ctx, blob.GetName())) + foundBlob, exists, err = store.Get(ctx, blob.GetName()) + s.NoError(err) + s.False(exists) + s.Nil(foundBlob) + s.ErrorIs(store.Delete(withNoAccessCtx, blob.GetName()), sac.ErrResourceAccessDenied) + + var blobs []*storage.Blob + var blobIDs []string + for i := 0; i < 200; i++ { + blob := &storage.Blob{} + s.NoError(testutils.FullInit(blob, testutils.UniqueInitializer(), testutils.JSONFieldsFilter)) + blobs = append(blobs, blob) + blobIDs = append(blobIDs, blob.GetName()) + } + + s.NoError(store.UpsertMany(ctx, blobs)) + + blobCount, err = store.Count(ctx) + s.NoError(err) + s.Equal(200, blobCount) + + s.NoError(store.DeleteMany(ctx, blobIDs)) + + blobCount, err = store.Count(ctx) + s.NoError(err) + s.Equal(0, blobCount) +} diff --git a/central/blob/datastore/store/store.go b/central/blob/datastore/store/store.go new file mode 100644 index 0000000000000..b3e7c88d4a245 --- /dev/null +++ b/central/blob/datastore/store/store.go @@ -0,0 +1,148 @@ +package store + +import ( + "context" + "io" + + 
"github.com/jackc/pgx/v4" + "github.com/pkg/errors" + "github.com/stackrox/rox/central/blob/datastore/store/postgres" + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/logging" + pgPkg "github.com/stackrox/rox/pkg/postgres" +) + +var log = logging.LoggerForModule() + +// Store is the interface to interact with the storage for storage.Blob +type Store interface { + Upsert(ctx context.Context, obj *storage.Blob, reader io.Reader) error + Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) +} + +type storeImpl struct { + db *pgPkg.DB + store postgres.Store +} + +// New creates a new Blob store +func New(db *pgPkg.DB) Store { + return &storeImpl{ + db: db, + store: postgres.New(db), + } +} + +func wrapRollback(ctx context.Context, tx *pgPkg.Tx, err error) error { + rollbackErr := tx.Rollback(ctx) + if rollbackErr != nil { + return errors.Wrapf(rollbackErr, "rolling back due to err: %v", err) + } + return err +} + +// Upsert adds a blob to the database +func (s *storeImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io.Reader) error { + existingBlob, exists, err := s.store.Get(ctx, obj.GetName()) + if err != nil { + return err + } + tx, err := s.db.Begin(ctx) + if err != nil { + return err + } + ctx = pgPkg.ContextWithTx(ctx, tx) + + los := tx.LargeObjects() + var lo *pgx.LargeObject + if exists { + lo, err = los.Open(ctx, existingBlob.GetOid(), pgx.LargeObjectModeWrite) + if err != nil { + return wrapRollback(ctx, tx, errors.Wrapf(err, "opening blob with oid %d", existingBlob.GetOid())) + } + if err := lo.Truncate(0); err != nil { + return errors.Wrapf(err, "truncating blob with oid %d", existingBlob.GetOid()) + } + } else { + oid, err := los.Create(ctx, 0) + if err != nil { + return wrapRollback(ctx, tx, errors.Wrap(err, "error creating new blob")) + } + lo, err = los.Open(ctx, oid, pgx.LargeObjectModeWrite) + if err != nil { + return wrapRollback(ctx, tx, errors.Wrapf(err, "opening blob with oid 
%d", oid)) + } + obj.Oid = oid + } + buf := make([]byte, 1024*1024) + for { + nRead, err := reader.Read(buf) + + if nRead != 0 { + if _, err := lo.Write(buf[:nRead]); err != nil { + return wrapRollback(ctx, tx, errors.Wrap(err, "writing blob")) + } + } + + // nRead can be non-zero when err == io.EOF + if err != nil { + if err == io.EOF { + break + } + return wrapRollback(ctx, tx, errors.Wrap(err, "reading buffer to write for blob")) + } + } + if err := lo.Close(); err != nil { + return wrapRollback(ctx, tx, errors.Wrap(err, "closing large object for blob")) + } + + if err := s.store.Upsert(ctx, obj); err != nil { + return wrapRollback(ctx, tx, errors.Wrapf(err, "error upserting blob %q", obj.GetName())) + } + return tx.Commit(ctx) +} + +// Get returns a blob from the database +func (s *storeImpl) Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) { + existingBlob, exists, err := s.store.Get(ctx, name) + if err != nil || !exists { + return nil, exists, err + } + + tx, err := s.db.Begin(ctx) + if err != nil { + return nil, false, err + } + ctx = pgPkg.ContextWithTx(ctx, tx) + + los := tx.LargeObjects() + lo, err := los.Open(ctx, existingBlob.GetOid(), pgx.LargeObjectModeWrite) + if err != nil { + err := errors.Wrapf(err, "error opening large object with oid %d", existingBlob.GetOid()) + return nil, false, wrapRollback(ctx, tx, err) + } + + buf := make([]byte, 1024*1024) + for { + nRead, err := lo.Read(buf) + + // nRead can be non-zero when err == io.EOF + if nRead != 0 { + if _, err := writer.Write(buf[:nRead]); err != nil { + err := errors.Wrap(err, "error writing to output") + return nil, false, wrapRollback(ctx, tx, err) + } + } + if err != nil { + if err == io.EOF { + break + } + } + } + if err := lo.Close(); err != nil { + err = errors.Wrap(err, "closing large object for blob") + return nil, false, wrapRollback(ctx, tx, err) + } + + return existingBlob, true, tx.Commit(ctx) +} diff --git 
a/central/blob/datastore/store/store_test.go b/central/blob/datastore/store/store_test.go new file mode 100644 index 0000000000000..77d6c28fbacd5 --- /dev/null +++ b/central/blob/datastore/store/store_test.go @@ -0,0 +1,74 @@ +//go:build sql_integration + +package store + +import ( + "bytes" + "context" + "math/rand" + "testing" + + timestamp "github.com/gogo/protobuf/types" + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/postgres/pgtest" + "github.com/stackrox/rox/pkg/sac" + "github.com/stretchr/testify/suite" +) + +type BlobsStoreSuite struct { + suite.Suite + store Store + testDB *pgtest.TestPostgres +} + +func TestBlobsStore(t *testing.T) { + suite.Run(t, new(BlobsStoreSuite)) +} + +func (s *BlobsStoreSuite) SetupSuite() { + s.testDB = pgtest.ForT(s.T()) + s.store = New(s.testDB.DB) +} + +func (s *BlobsStoreSuite) SetupTest() { + ctx := sac.WithAllAccess(context.Background()) + tag, err := s.testDB.Exec(ctx, "TRUNCATE blobs CASCADE") + s.T().Log("blobs", tag) + s.NoError(err) +} + +func (s *BlobsStoreSuite) TearDownSuite() { + s.testDB.Teardown(s.T()) +} + +func (s *BlobsStoreSuite) TestStore() { + ctx := sac.WithAllAccess(context.Background()) + + insertBlob := &storage.Blob{ + Name: "test", + LastUpdated: timestamp.TimestampNow(), + ModifiedTime: timestamp.TimestampNow(), + } + + buf := &bytes.Buffer{} + _, exists, err := s.store.Get(ctx, insertBlob.GetName(), buf) + s.Require().NoError(err) + s.Require().False(exists) + + size := 1024*1024 + 16 + randomData := make([]byte, size) + _, err = rand.Read(randomData) + s.NoError(err) + + reader := bytes.NewBuffer(randomData) + + s.Require().NoError(s.store.Upsert(ctx, insertBlob, reader)) + + buf = &bytes.Buffer{} + blob, exists, err := s.store.Get(ctx, insertBlob.GetName(), buf) + s.Require().NoError(err) + s.Require().True(exists) + s.NotZero(blob.GetOid()) + s.Equal(insertBlob, blob) + s.Equal(randomData, buf.Bytes()) +} diff --git a/generated/storage/blob.pb.go 
b/generated/storage/blob.pb.go new file mode 100644 index 0000000000000..74882046ff782 --- /dev/null +++ b/generated/storage/blob.pb.go @@ -0,0 +1,557 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: storage/blob.proto + +package storage + +import ( + fmt "fmt" + _ "github.com/gogo/protobuf/gogoproto" + types "github.com/gogo/protobuf/types" + proto "github.com/golang/protobuf/proto" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +// Next Tag: 4 +type Blob struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty" sql:"pk"` + Oid uint32 `protobuf:"varint,2,opt,name=oid,proto3" json:"oid,omitempty"` + Checksum string `protobuf:"bytes,3,opt,name=checksum,proto3" json:"checksum,omitempty"` + LastUpdated *types.Timestamp `protobuf:"bytes,4,opt,name=last_updated,json=lastUpdated,proto3" json:"last_updated,omitempty"` + ModifiedTime *types.Timestamp `protobuf:"bytes,5,opt,name=modified_time,json=modifiedTime,proto3" json:"modified_time,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Blob) Reset() { *m = Blob{} } +func (m *Blob) String() string { return proto.CompactTextString(m) } +func (*Blob) ProtoMessage() {} +func (*Blob) Descriptor() ([]byte, []int) { + return fileDescriptor_93b63e008eb8666f, []int{0} +} +func (m *Blob) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Blob) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic 
{ + return xxx_messageInfo_Blob.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Blob) XXX_Merge(src proto.Message) { + xxx_messageInfo_Blob.Merge(m, src) +} +func (m *Blob) XXX_Size() int { + return m.Size() +} +func (m *Blob) XXX_DiscardUnknown() { + xxx_messageInfo_Blob.DiscardUnknown(m) +} + +var xxx_messageInfo_Blob proto.InternalMessageInfo + +func (m *Blob) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *Blob) GetOid() uint32 { + if m != nil { + return m.Oid + } + return 0 +} + +func (m *Blob) GetChecksum() string { + if m != nil { + return m.Checksum + } + return "" +} + +func (m *Blob) GetLastUpdated() *types.Timestamp { + if m != nil { + return m.LastUpdated + } + return nil +} + +func (m *Blob) GetModifiedTime() *types.Timestamp { + if m != nil { + return m.ModifiedTime + } + return nil +} + +func (m *Blob) MessageClone() proto.Message { + return m.Clone() +} +func (m *Blob) Clone() *Blob { + if m == nil { + return nil + } + cloned := new(Blob) + *cloned = *m + + cloned.LastUpdated = m.LastUpdated.Clone() + cloned.ModifiedTime = m.ModifiedTime.Clone() + return cloned +} + +func init() { + proto.RegisterType((*Blob)(nil), "storage.Blob") +} + +func init() { proto.RegisterFile("storage/blob.proto", fileDescriptor_93b63e008eb8666f) } + +var fileDescriptor_93b63e008eb8666f = []byte{ + // 273 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2a, 0x2e, 0xc9, 0x2f, + 0x4a, 0x4c, 0x4f, 0xd5, 0x4f, 0xca, 0xc9, 0x4f, 0xd2, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, + 0x87, 0x8a, 0x49, 0x89, 0xa4, 0xe7, 0xa7, 0xe7, 0x83, 0xc5, 0xf4, 0x41, 0x2c, 0x88, 0xb4, 0x94, + 0x7c, 0x7a, 0x7e, 0x7e, 0x7a, 0x4e, 0xaa, 0x3e, 0x98, 0x97, 0x54, 0x9a, 0xa6, 0x5f, 0x92, 0x99, + 0x9b, 0x5a, 0x5c, 0x92, 0x98, 0x5b, 0x00, 0x51, 0xa0, 0x74, 0x8d, 0x91, 0x8b, 0xc5, 0x29, 
0x27, + 0x3f, 0x49, 0x48, 0x81, 0x8b, 0x25, 0x2f, 0x31, 0x37, 0x55, 0x82, 0x51, 0x81, 0x51, 0x83, 0xd3, + 0x89, 0xe7, 0xd3, 0x3d, 0x79, 0x8e, 0xe2, 0xc2, 0x1c, 0x2b, 0xa5, 0x82, 0x6c, 0xa5, 0x20, 0xb0, + 0x8c, 0x90, 0x00, 0x17, 0x73, 0x7e, 0x66, 0x8a, 0x04, 0x93, 0x02, 0xa3, 0x06, 0x6f, 0x10, 0x88, + 0x29, 0x24, 0xc5, 0xc5, 0x91, 0x9c, 0x91, 0x9a, 0x9c, 0x5d, 0x5c, 0x9a, 0x2b, 0xc1, 0x0c, 0xd2, + 0x17, 0x04, 0xe7, 0x0b, 0xd9, 0x72, 0xf1, 0xe4, 0x24, 0x16, 0x97, 0xc4, 0x97, 0x16, 0xa4, 0x24, + 0x96, 0xa4, 0xa6, 0x48, 0xb0, 0x28, 0x30, 0x6a, 0x70, 0x1b, 0x49, 0xe9, 0x41, 0x1c, 0xa4, 0x07, + 0x73, 0x90, 0x5e, 0x08, 0xcc, 0x41, 0x41, 0xdc, 0x20, 0xf5, 0xa1, 0x10, 0xe5, 0x42, 0xf6, 0x5c, + 0xbc, 0xb9, 0xf9, 0x29, 0x99, 0x69, 0x99, 0xa9, 0x29, 0xf1, 0x20, 0x37, 0x4b, 0xb0, 0x12, 0xd4, + 0xcf, 0x03, 0xd3, 0x00, 0x12, 0x72, 0x32, 0x39, 0xf1, 0x48, 0x8e, 0xf1, 0xc2, 0x23, 0x39, 0xc6, + 0x07, 0x8f, 0xe4, 0x18, 0x67, 0x3c, 0x96, 0x63, 0xe0, 0x92, 0xcc, 0xcc, 0xd7, 0x2b, 0x2e, 0x49, + 0x4c, 0xce, 0x2e, 0xca, 0xaf, 0x80, 0xe8, 0xd7, 0x83, 0x06, 0x5e, 0x14, 0x2c, 0x14, 0x93, 0xd8, + 0xc0, 0xe2, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0xc0, 0x63, 0xb7, 0xe3, 0x6b, 0x01, 0x00, + 0x00, +} + +func (m *Blob) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Blob) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Blob) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.XXX_unrecognized != nil { + i -= len(m.XXX_unrecognized) + copy(dAtA[i:], m.XXX_unrecognized) + } + if m.ModifiedTime != nil { + { + size, err := m.ModifiedTime.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintBlob(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x2a + } + if m.LastUpdated != nil { + 
{ + size, err := m.LastUpdated.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintBlob(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x22 + } + if len(m.Checksum) > 0 { + i -= len(m.Checksum) + copy(dAtA[i:], m.Checksum) + i = encodeVarintBlob(dAtA, i, uint64(len(m.Checksum))) + i-- + dAtA[i] = 0x1a + } + if m.Oid != 0 { + i = encodeVarintBlob(dAtA, i, uint64(m.Oid)) + i-- + dAtA[i] = 0x10 + } + if len(m.Name) > 0 { + i -= len(m.Name) + copy(dAtA[i:], m.Name) + i = encodeVarintBlob(dAtA, i, uint64(len(m.Name))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func encodeVarintBlob(dAtA []byte, offset int, v uint64) int { + offset -= sovBlob(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *Blob) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Name) + if l > 0 { + n += 1 + l + sovBlob(uint64(l)) + } + if m.Oid != 0 { + n += 1 + sovBlob(uint64(m.Oid)) + } + l = len(m.Checksum) + if l > 0 { + n += 1 + l + sovBlob(uint64(l)) + } + if m.LastUpdated != nil { + l = m.LastUpdated.Size() + n += 1 + l + sovBlob(uint64(l)) + } + if m.ModifiedTime != nil { + l = m.ModifiedTime.Size() + n += 1 + l + sovBlob(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func sovBlob(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozBlob(x uint64) (n int) { + return sovBlob(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *Blob) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowBlob + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := 
int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Blob: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Blob: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Name", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowBlob + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthBlob + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthBlob + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Name = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Oid", wireType) + } + m.Oid = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowBlob + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Oid |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Checksum", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowBlob + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthBlob + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthBlob + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Checksum = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 4: + if wireType != 2 { + return 
fmt.Errorf("proto: wrong wireType = %d for field LastUpdated", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowBlob + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthBlob + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthBlob + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.LastUpdated == nil { + m.LastUpdated = &types.Timestamp{} + } + if err := m.LastUpdated.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ModifiedTime", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowBlob + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthBlob + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthBlob + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.ModifiedTime == nil { + m.ModifiedTime = &types.Timestamp{} + } + if err := m.ModifiedTime.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipBlob(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthBlob + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) 
+ iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipBlob(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowBlob + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowBlob + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowBlob + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthBlob + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupBlob + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthBlob + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthBlob = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowBlob = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupBlob = fmt.Errorf("proto: unexpected end of group") +) diff --git a/pkg/postgres/conn.go b/pkg/postgres/conn.go index 537ffc3a51476..d040b200b5eed 100644 --- a/pkg/postgres/conn.go +++ b/pkg/postgres/conn.go @@ -42,6 +42,10 @@ func (c *Conn) Exec(ctx context.Context, sql string, args ...interface{}) (pgcon ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, 
defaultTimeout) defer cancel() + if tx, ok := TxFromContext(ctx); ok { + return tx.Exec(ctx, sql, args...) + } + defer setQueryDuration(time.Now(), "conn", sql) ct, err := c.Conn.Exec(ctx, sql, args...) if err != nil { @@ -55,6 +59,10 @@ func (c *Conn) Exec(ctx context.Context, sql string, args ...interface{}) (pgcon func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) + if tx, ok := TxFromContext(ctx); ok { + return tx.Query(ctx, sql, args...) + } + defer setQueryDuration(time.Now(), "conn", sql) rows, err := c.Conn.Query(ctx, sql, args...) if err != nil { @@ -73,9 +81,16 @@ func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (*Row func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) *Row { ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) - defer setQueryDuration(time.Now(), "conn", sql) + var row pgx.Row + if tx, ok := TxFromContext(ctx); ok { + row = tx.QueryRow(ctx, sql, args...) + } else { + defer setQueryDuration(time.Now(), "conn", sql) + row = c.Conn.QueryRow(ctx, sql, args...) 
+ } + return &Row{ - Row: c.Conn.QueryRow(ctx, sql, args...), + Row: row, query: sql, cancelFunc: cancel, } @@ -85,8 +100,26 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) *R func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) *BatchResults { ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) + var batchResults pgx.BatchResults + if tx, ok := TxFromContext(ctx); ok { + batchResults = tx.SendBatch(ctx, b) + } else { + batchResults = c.Conn.SendBatch(ctx, b) + } + return &BatchResults{ - BatchResults: c.Conn.SendBatch(ctx, b), + BatchResults: batchResults, cancel: cancel, } } + +// CopyFrom wraps pgxpool.Conn CopyFrom +func (c *Conn) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) + defer cancel() + + if tx, ok := TxFromContext(ctx); ok { + return tx.CopyFrom(ctx, tableName, columnNames, rowSrc) + } + return c.Conn.CopyFrom(ctx, tableName, columnNames, rowSrc) +} diff --git a/pkg/postgres/context.go b/pkg/postgres/context.go new file mode 100644 index 0000000000000..2cf0d21727aea --- /dev/null +++ b/pkg/postgres/context.go @@ -0,0 +1,19 @@ +package postgres + +import "context" + +type txContextKey struct{} + +// ContextWithTx adds a database transaction to the context +func ContextWithTx(ctx context.Context, tx *Tx) context.Context { + return context.WithValue(ctx, txContextKey{}, tx) +} + +// TxFromContext gets a database transaction from the context if it exists +func TxFromContext(ctx context.Context) (*Tx, bool) { + obj := ctx.Value(txContextKey{}) + if obj == nil { + return nil, false + } + return obj.(*Tx), true +} diff --git a/pkg/postgres/pool.go b/pkg/postgres/pool.go index a29be65818113..cfb468aa98c7e 100644 --- a/pkg/postgres/pool.go +++ b/pkg/postgres/pool.go @@ -72,6 +72,10 @@ func (d *DB) Exec(ctx context.Context, sql string, args 
...interface{}) (pgconn. ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) defer cancel() + if tx, ok := TxFromContext(ctx); ok { + return tx.Exec(ctx, sql, args...) + } + defer setQueryDuration(time.Now(), "pool", sql) ct, err := d.Pool.Exec(ctx, sql, args...) if err != nil { @@ -85,6 +89,10 @@ func (d *DB) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn. func (d *DB) Query(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) + if tx, ok := TxFromContext(ctx); ok { + return tx.Query(ctx, sql, args...) + } + defer setQueryDuration(time.Now(), "pool", sql) rows, err := d.Pool.Query(ctx, sql, args...) if err != nil { @@ -99,17 +107,34 @@ func (d *DB) Query(ctx context.Context, sql string, args ...interface{}) (*Rows, } // QueryRow wraps pgxpool.Pool QueryRow -func (d *DB) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { +func (d *DB) QueryRow(ctx context.Context, sql string, args ...interface{}) *Row { ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) - defer setQueryDuration(time.Now(), "pool", sql) + var row pgx.Row + if tx, ok := TxFromContext(ctx); ok { + row = tx.QueryRow(ctx, sql, args...) + } else { + defer setQueryDuration(time.Now(), "pool", sql) + row = d.Pool.QueryRow(ctx, sql, args...) 
+ } return &Row{ - Row: d.Pool.QueryRow(ctx, sql, args...), + Row: row, query: sql, cancelFunc: cancel, } } +// CopyFrom wraps pgxpool.Pool CopyFrom +func (d *DB) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) + defer cancel() + + if tx, ok := TxFromContext(ctx); ok { + return tx.CopyFrom(ctx, tableName, columnNames, rowSrc) + } + return d.Pool.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + // Acquire wraps pgxpool.Acquire func (d *DB) Acquire(ctx context.Context) (*Conn, error) { conn, err := d.Pool.Acquire(ctx) diff --git a/pkg/postgres/schema/blobs.go b/pkg/postgres/schema/blobs.go new file mode 100644 index 0000000000000..c73ea20ca81d7 --- /dev/null +++ b/pkg/postgres/schema/blobs.go @@ -0,0 +1,40 @@ +// Code generated by pg-bindings generator. DO NOT EDIT. + +package schema + +import ( + "reflect" + + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/postgres" + "github.com/stackrox/rox/pkg/postgres/walker" +) + +var ( + // CreateTableBlobsStmt holds the create statement for table `blobs`. + CreateTableBlobsStmt = &postgres.CreateStmts{ + GormModel: (*Blobs)(nil), + Children: []*postgres.CreateStmts{}, + } + + // BlobsSchema is the go schema for table `blobs`. + BlobsSchema = func() *walker.Schema { + schema := GetSchemaForTable("blobs") + if schema != nil { + return schema + } + schema = walker.Walk(reflect.TypeOf((*storage.Blob)(nil)), "blobs") + RegisterTable(schema, CreateTableBlobsStmt) + return schema + }() +) + +const ( + BlobsTableName = "blobs" +) + +// Blobs holds the Gorm model for Postgres table `blobs`. 
+type Blobs struct { + Name string `gorm:"column:name;type:varchar;primaryKey"` + Serialized []byte `gorm:"column:serialized;type:bytea"` +} diff --git a/tools/generate-helpers/pg-table-bindings/list.go b/tools/generate-helpers/pg-table-bindings/list.go index 87f5a3cc5b495..28b7c17018326 100644 --- a/tools/generate-helpers/pg-table-bindings/list.go +++ b/tools/generate-helpers/pg-table-bindings/list.go @@ -19,6 +19,7 @@ func init() { for s, r := range map[proto.Message]permissions.ResourceHandle{ &storage.ActiveComponent{}: resources.Deployment, &storage.AuthProvider{}: resources.Access, + &storage.Blob{}: resources.Administration, &storage.ClusterHealthStatus{}: resources.Cluster, &storage.ClusterCVE{}: resources.Cluster, &storage.ClusterCVEEdge{}: resources.Cluster, From b77920a3b5baa2792c79d2ef19ffc84d06120eb9 Mon Sep 17 00:00:00 2001 From: Connor Gorman Date: Mon, 13 Mar 2023 21:42:42 -0700 Subject: [PATCH 02/40] Add utilities to pass transactions through contexts --- pkg/postgres/conn.go | 39 ++++++++++++++++++++++++++++++++++++--- pkg/postgres/context.go | 19 +++++++++++++++++++ pkg/postgres/pool.go | 31 ++++++++++++++++++++++++++++--- 3 files changed, 83 insertions(+), 6 deletions(-) create mode 100644 pkg/postgres/context.go diff --git a/pkg/postgres/conn.go b/pkg/postgres/conn.go index 537ffc3a51476..d040b200b5eed 100644 --- a/pkg/postgres/conn.go +++ b/pkg/postgres/conn.go @@ -42,6 +42,10 @@ func (c *Conn) Exec(ctx context.Context, sql string, args ...interface{}) (pgcon ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) defer cancel() + if tx, ok := TxFromContext(ctx); ok { + return tx.Exec(ctx, sql, args...) + } + defer setQueryDuration(time.Now(), "conn", sql) ct, err := c.Conn.Exec(ctx, sql, args...) 
if err != nil { @@ -55,6 +59,10 @@ func (c *Conn) Exec(ctx context.Context, sql string, args ...interface{}) (pgcon func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) + if tx, ok := TxFromContext(ctx); ok { + return tx.Query(ctx, sql, args...) + } + defer setQueryDuration(time.Now(), "conn", sql) rows, err := c.Conn.Query(ctx, sql, args...) if err != nil { @@ -73,9 +81,16 @@ func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (*Row func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) *Row { ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) - defer setQueryDuration(time.Now(), "conn", sql) + var row pgx.Row + if tx, ok := TxFromContext(ctx); ok { + row = tx.QueryRow(ctx, sql, args...) + } else { + defer setQueryDuration(time.Now(), "conn", sql) + row = c.Conn.QueryRow(ctx, sql, args...) + } + return &Row{ - Row: c.Conn.QueryRow(ctx, sql, args...), + Row: row, query: sql, cancelFunc: cancel, } @@ -85,8 +100,26 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) *R func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) *BatchResults { ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) + var batchResults pgx.BatchResults + if tx, ok := TxFromContext(ctx); ok { + batchResults = tx.SendBatch(ctx, b) + } else { + batchResults = c.Conn.SendBatch(ctx, b) + } + return &BatchResults{ - BatchResults: c.Conn.SendBatch(ctx, b), + BatchResults: batchResults, cancel: cancel, } } + +// CopyFrom wraps pgxpool.Conn CopyFrom +func (c *Conn) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) + defer cancel() + + if tx, ok := TxFromContext(ctx); ok { + return tx.CopyFrom(ctx, tableName, 
columnNames, rowSrc) + } + return c.Conn.CopyFrom(ctx, tableName, columnNames, rowSrc) +} diff --git a/pkg/postgres/context.go b/pkg/postgres/context.go new file mode 100644 index 0000000000000..2cf0d21727aea --- /dev/null +++ b/pkg/postgres/context.go @@ -0,0 +1,19 @@ +package postgres + +import "context" + +type txContextKey struct{} + +// ContextWithTx adds a database transaction to the context +func ContextWithTx(ctx context.Context, tx *Tx) context.Context { + return context.WithValue(ctx, txContextKey{}, tx) +} + +// TxFromContext gets a database transaction from the context if it exists +func TxFromContext(ctx context.Context) (*Tx, bool) { + obj := ctx.Value(txContextKey{}) + if obj == nil { + return nil, false + } + return obj.(*Tx), true +} diff --git a/pkg/postgres/pool.go b/pkg/postgres/pool.go index a29be65818113..cfb468aa98c7e 100644 --- a/pkg/postgres/pool.go +++ b/pkg/postgres/pool.go @@ -72,6 +72,10 @@ func (d *DB) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn. ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) defer cancel() + if tx, ok := TxFromContext(ctx); ok { + return tx.Exec(ctx, sql, args...) + } + defer setQueryDuration(time.Now(), "pool", sql) ct, err := d.Pool.Exec(ctx, sql, args...) if err != nil { @@ -85,6 +89,10 @@ func (d *DB) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn. func (d *DB) Query(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) + if tx, ok := TxFromContext(ctx); ok { + return tx.Query(ctx, sql, args...) + } + defer setQueryDuration(time.Now(), "pool", sql) rows, err := d.Pool.Query(ctx, sql, args...) 
if err != nil { @@ -99,17 +107,34 @@ func (d *DB) Query(ctx context.Context, sql string, args ...interface{}) (*Rows, } // QueryRow wraps pgxpool.Pool QueryRow -func (d *DB) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { +func (d *DB) QueryRow(ctx context.Context, sql string, args ...interface{}) *Row { ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) - defer setQueryDuration(time.Now(), "pool", sql) + var row pgx.Row + if tx, ok := TxFromContext(ctx); ok { + row = tx.QueryRow(ctx, sql, args...) + } else { + defer setQueryDuration(time.Now(), "pool", sql) + row = d.Pool.QueryRow(ctx, sql, args...) + } return &Row{ - Row: d.Pool.QueryRow(ctx, sql, args...), + Row: row, query: sql, cancelFunc: cancel, } } +// CopyFrom wraps pgxpool.Pool CopyFrom +func (d *DB) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + ctx, cancel := contextutil.ContextWithTimeoutIfNotExists(ctx, defaultTimeout) + defer cancel() + + if tx, ok := TxFromContext(ctx); ok { + return tx.CopyFrom(ctx, tableName, columnNames, rowSrc) + } + return d.Pool.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + // Acquire wraps pgxpool.Acquire func (d *DB) Acquire(ctx context.Context) (*Conn, error) { conn, err := d.Pool.Acquire(ctx) From 7d4d49315ac5b41b68170e3fe265b32ff2c123e8 Mon Sep 17 00:00:00 2001 From: Connor Gorman Date: Tue, 14 Mar 2023 08:13:07 -0700 Subject: [PATCH 03/40] dont forget the proto file --- proto/storage/blob.proto | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 proto/storage/blob.proto diff --git a/proto/storage/blob.proto b/proto/storage/blob.proto new file mode 100644 index 0000000000000..96fc04fa515c8 --- /dev/null +++ b/proto/storage/blob.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +option go_package = "storage"; +option java_package = "io.stackrox.proto.storage"; + +import "gogoproto/gogo.proto"; +import 
"google/protobuf/timestamp.proto"; + +package storage; + +// Next Tag: 4 +message Blob { + string name = 1 [(gogoproto.moretags) = 'sql:"pk"']; + uint32 oid = 2; + string checksum = 3; + google.protobuf.Timestamp last_updated = 4; + google.protobuf.Timestamp modified_time = 5; +} From 92d1c0e295353d5ab12f6cfc94941eed243bfc01 Mon Sep 17 00:00:00 2001 From: Cong Du Date: Fri, 28 Apr 2023 09:36:47 -0700 Subject: [PATCH 04/40] stage --- central/blob/datastore/store/postgres/store.go | 10 +++++----- central/blob/datastore/store/store.go | 4 ++-- proto/storage/blob.proto | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/central/blob/datastore/store/postgres/store.go b/central/blob/datastore/store/postgres/store.go index 421213e2dd50d..f4f74b960e740 100644 --- a/central/blob/datastore/store/postgres/store.go +++ b/central/blob/datastore/store/postgres/store.go @@ -67,12 +67,12 @@ type Store interface { } type storeImpl struct { - db *postgres.DB + db postgres.DB mutex sync.RWMutex } // New returns a new Store instance using the provided sql instance. -func New(db *postgres.DB) Store { +func New(db postgres.DB) Store { return &storeImpl{ db: db, } @@ -518,17 +518,17 @@ func (s *storeImpl) GetKeysToIndex(ctx context.Context) ([]string, error) { //// Used for testing // CreateTableAndNewStore returns a new Store instance for testing. -func CreateTableAndNewStore(ctx context.Context, db *postgres.DB, gormDB *gorm.DB) Store { +func CreateTableAndNewStore(ctx context.Context, db postgres.DB, gormDB *gorm.DB) Store { pkgSchema.ApplySchemaForTable(ctx, gormDB, baseTable) return New(db) } // Destroy drops the tables associated with the target object type. 
-func Destroy(ctx context.Context, db *postgres.DB) { +func Destroy(ctx context.Context, db postgres.DB) { dropTableBlobs(ctx, db) } -func dropTableBlobs(ctx context.Context, db *postgres.DB) { +func dropTableBlobs(ctx context.Context, db postgres.DB) { _, _ = db.Exec(ctx, "DROP TABLE IF EXISTS blobs CASCADE") } diff --git a/central/blob/datastore/store/store.go b/central/blob/datastore/store/store.go index b3e7c88d4a245..d9594087de7d5 100644 --- a/central/blob/datastore/store/store.go +++ b/central/blob/datastore/store/store.go @@ -21,12 +21,12 @@ type Store interface { } type storeImpl struct { - db *pgPkg.DB + db pgPkg.DB store postgres.Store } // New creates a new Blob store -func New(db *pgPkg.DB) Store { +func New(db pgPkg.DB) Store { return &storeImpl{ db: db, store: postgres.New(db), diff --git a/proto/storage/blob.proto b/proto/storage/blob.proto index 96fc04fa515c8..9d97a5493d64b 100644 --- a/proto/storage/blob.proto +++ b/proto/storage/blob.proto @@ -8,7 +8,7 @@ import "google/protobuf/timestamp.proto"; package storage; -// Next Tag: 4 +// Next Tag: 6 message Blob { string name = 1 [(gogoproto.moretags) = 'sql:"pk"']; uint32 oid = 2; From d40b4ec22f8304282d741523cee7593e89848be7 Mon Sep 17 00:00:00 2001 From: Cong Du Date: Fri, 28 Apr 2023 17:57:29 -0700 Subject: [PATCH 05/40] stage --- central/blob/datastore/datastore.go | 12 +++- central/blob/datastore/store/store.go | 56 ++++++++++--------- central/scannerdefinitions/handler/handler.go | 51 ++++++++++------- 3 files changed, 71 insertions(+), 48 deletions(-) diff --git a/central/blob/datastore/datastore.go b/central/blob/datastore/datastore.go index 6c8f8360d5d1a..cb53eb61056ec 100644 --- a/central/blob/datastore/datastore.go +++ b/central/blob/datastore/datastore.go @@ -11,7 +11,8 @@ import ( // Datastore provides access to the blob store type Datastore interface { Upsert(ctx context.Context, obj *storage.Blob, reader io.Reader) error - Get(ctx context.Context, name string, writer io.Writer) 
(*storage.Blob, bool, error) + Get(ctx context.Context, name string) (*storage.Blob, io.ReadCloser, bool, error) + Delete(ctx context.Context, name string) error } // NewDatastore creates a new Blob datastore @@ -31,6 +32,11 @@ func (d *datastoreImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io } // Get retrieves a blob from the database -func (d *datastoreImpl) Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) { - return d.store.Get(ctx, name, writer) +func (d *datastoreImpl) Get(ctx context.Context, name string) (*storage.Blob, io.ReadCloser, bool, error) { + return d.store.Get(ctx, name) +} + +// Delete removes a blob store from database +func (d *datastoreImpl) Delete(ctx context.Context, name string) error { + return d.store.Delete(ctx, name) } diff --git a/central/blob/datastore/store/store.go b/central/blob/datastore/store/store.go index d9594087de7d5..0f63dd6bf8550 100644 --- a/central/blob/datastore/store/store.go +++ b/central/blob/datastore/store/store.go @@ -17,7 +17,8 @@ var log = logging.LoggerForModule() // Store is the interface to interact with the storage for storage.Blob type Store interface { Upsert(ctx context.Context, obj *storage.Blob, reader io.Reader) error - Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) + Get(ctx context.Context, name string) (*storage.Blob, io.ReadCloser, bool, error) + Delete(ctx context.Context, name string) error } type storeImpl struct { @@ -103,46 +104,51 @@ func (s *storeImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io.Rea } // Get returns a blob from the database -func (s *storeImpl) Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) { +func (s *storeImpl) Get(ctx context.Context, name string) (*storage.Blob, io.ReadCloser, bool, error) { existingBlob, exists, err := s.store.Get(ctx, name) if err != nil || !exists { - return nil, exists, err + return nil, nil, exists, err } tx, err 
:= s.db.Begin(ctx) if err != nil { - return nil, false, err + return nil, nil, false, err } ctx = pgPkg.ContextWithTx(ctx, tx) los := tx.LargeObjects() - lo, err := los.Open(ctx, existingBlob.GetOid(), pgx.LargeObjectModeWrite) + lo, err := los.Open(ctx, existingBlob.GetOid(), pgx.LargeObjectModeRead) if err != nil { err := errors.Wrapf(err, "error opening large object with oid %d", existingBlob.GetOid()) - return nil, false, wrapRollback(ctx, tx, err) + return nil, nil, false, wrapRollback(ctx, tx, err) } - buf := make([]byte, 1024*1024) - for { - nRead, err := lo.Read(buf) + return existingBlob, lo, true, tx.Commit(ctx) +} - // nRead can be non-zero when err == io.EOF - if nRead != 0 { - if _, err := writer.Write(buf[:nRead]); err != nil { - err := errors.Wrap(err, "error writing to output") - return nil, false, wrapRollback(ctx, tx, err) - } - } - if err != nil { - if err == io.EOF { - break - } - } +// Delete removes a blob store from database +func (s *storeImpl) Delete(ctx context.Context, name string) error { + existingBlob, exists, err := s.store.Get(ctx, name) + if err != nil { + return err } - if err := lo.Close(); err != nil { - err = errors.Wrap(err, "closing large object for blob") - return nil, false, wrapRollback(ctx, tx, err) + if !exists { + return nil + } + tx, err := s.db.Begin(ctx) + if err != nil { + return err } - return existingBlob, true, tx.Commit(ctx) + ctx = pgPkg.ContextWithTx(ctx, tx) + los := tx.LargeObjects() + if err = los.Unlink(ctx, existingBlob.GetOid()); err != nil { + return errors.Wrapf(err, "failed to remove large object with oid %d", existingBlob.GetOid()) + } + if err = s.store.Delete(ctx, name); err != nil { + err = errors.Wrapf(err, "deleting large object %s", name) + return wrapRollback(ctx, tx, err) + } + + return tx.Commit(ctx) } diff --git a/central/scannerdefinitions/handler/handler.go b/central/scannerdefinitions/handler/handler.go index 4654559ea00e5..c2359d4c13dbd 100644 --- 
a/central/scannerdefinitions/handler/handler.go +++ b/central/scannerdefinitions/handler/handler.go @@ -2,6 +2,7 @@ package handler import ( "archive/zip" + "context" "io" "io/fs" "net/http" @@ -11,15 +12,18 @@ import ( "strings" "time" + timestamp "github.com/gogo/protobuf/types" "github.com/pkg/errors" + blob "github.com/stackrox/rox/central/blob/datastore" "github.com/stackrox/rox/central/cve/fetcher" "github.com/stackrox/rox/central/scannerdefinitions/file" + "github.com/stackrox/rox/generated/storage" "github.com/stackrox/rox/pkg/env" "github.com/stackrox/rox/pkg/fileutils" "github.com/stackrox/rox/pkg/httputil" "github.com/stackrox/rox/pkg/httputil/proxy" "github.com/stackrox/rox/pkg/logging" - "github.com/stackrox/rox/pkg/migrations" + "github.com/stackrox/rox/pkg/postgres/pgutils" "github.com/stackrox/rox/pkg/sync" "github.com/stackrox/rox/pkg/utils" "google.golang.org/grpc/codes" @@ -28,6 +32,9 @@ import ( const ( definitionsBaseDir = "scannerdefinitions" + scannerDefinitionBlobName = "offline.scanner.definitions" + scannerDefinitionTemp = "scanner-defs-*.zip" + // scannerDefsSubZipName represents the offline zip bundle for CVEs for Scanner. scannerDefsSubZipName = "scanner-defs.zip" // K8sIstioCveZipName represents the zip bundle for k8s/istio CVEs. @@ -67,6 +74,7 @@ type httpHandler struct { updaters map[string]*requestedUpdater onlineVulnDir string offlineFile *file.File + blobStore blob.Datastore } // New creates a new http.Handler to handle vulnerability data. @@ -74,12 +82,11 @@ func New(cveManager fetcher.OrchestratorIstioCVEManager, opts handlerOpts) http. 
h := &httpHandler{ cveManager: cveManager, - online: !env.OfflineModeEnv.BooleanSetting(), - interval: env.ScannerVulnUpdateInterval.DurationSetting(), + online: !env.OfflineModeEnv.BooleanSetting(), + interval: env.ScannerVulnUpdateInterval.DurationSetting(), + blobStore: blob.Singleton(), } - h.initializeOfflineVulnDump(opts.offlineVulnDefsDir) - if h.online { h.initializeUpdaters(opts.cleanupInterval, opts.cleanupAge) } else { @@ -89,14 +96,6 @@ func New(cveManager fetcher.OrchestratorIstioCVEManager, opts handlerOpts) http. return h } -func (h *httpHandler) initializeOfflineVulnDump(vulnDefsDir string) { - if vulnDefsDir == "" { - vulnDefsDir = filepath.Join(migrations.DBMountPath(), definitionsBaseDir) - } - - h.offlineFile = file.New(filepath.Join(vulnDefsDir, offlineScannerDefsName)) -} - func (h *httpHandler) initializeUpdaters(cleanupInterval, cleanupAge *time.Duration) { var err error h.onlineVulnDir, err = os.MkdirTemp("", definitionsBaseDir) @@ -212,14 +211,20 @@ func (h *httpHandler) handleScannerDefsFile(zipF *zip.File) error { defer utils.IgnoreError(r.Close) // POST requests only update the offline feed. 
- if err := h.offlineFile.Write(r, zipF.Modified); err != nil { + b := &storage.Blob{ + Name: scannerDefinationBlobName, + LastUpdated: timestamp.TimestampNow(), + ModifiedTime: timestamp.TimestampNow(), + } + + if err := h.blobStore.Upsert(context.Background(), b, r); err != nil { return errors.Wrap(err, "writing scanner definitions") } return nil } -func (h *httpHandler) handleZipContentsFromVulnDump(zipPath string) error { +func (h *httpHandler) handleZipContentsFromVulnDump(ctx context.Context, zipPath string) error { zipR, err := zip.OpenReader(zipPath) if err != nil { return errors.Wrap(err, "couldn't open file as zip") @@ -263,7 +268,7 @@ func (h *httpHandler) post(w http.ResponseWriter, r *http.Request) { return } - if err := h.handleZipContentsFromVulnDump(tempFile); err != nil { + if err := h.handleZipContentsFromVulnDump(r.Context(), tempFile); err != nil { httputil.WriteGRPCStyleError(w, codes.Internal, err) return } @@ -308,10 +313,16 @@ func (h *httpHandler) cleanupUpdaters(cleanupAge time.Duration) { // online, otherwise fallback to the manually uploaded definitions. The file // object can be `nil` if the definitions file does not exist, rather than // returning an error. -func (h *httpHandler) openMostRecentDefinitions(uuid string) (file *os.File, modTime time.Time, err error) { +func (h *httpHandler) openMostRecentDefinitions(uuid string) (rc io.ReadCloser, modTime time.Time, err error) { // If in offline mode or uuid is not provided, default to the offline file. 
if !h.online || uuid == "" { - file, modTime, err = h.offlineFile.Open() + var blob *storage.Blob + blob, rc, _, err = h.blobStore.Get(context.Background(), scannerDefinitionBlobName) + if err != nil { + err = errors.Wrapf(err, "failed to open offline scanner definition bundle") + return + } + modTime = *pgutils.NilOrTime(blob.GetLastUpdated()) return } @@ -336,10 +347,10 @@ func (h *httpHandler) openMostRecentDefinitions(uuid string) (file *os.File, mod // since modification time will be zero. if offlineTime.After(onlineTime) { - file, modTime = offlineFile, offlineTime + rc, modTime = offlineFile, offlineTime utils.IgnoreError(onlineFile.Close) } else { - file, modTime = onlineFile, onlineTime + rc, modTime = onlineFile, onlineTime utils.IgnoreError(offlineFile.Close) } From 4384c9ad1b04ef0bd9638ef1a0aed40541d89623 Mon Sep 17 00:00:00 2001 From: Cong Du Date: Mon, 1 May 2023 09:40:54 -0700 Subject: [PATCH 06/40] Add delete --- central/blob/datastore/datastore.go | 6 +++++ central/blob/datastore/store/store.go | 31 +++++++++++++++++++++- central/blob/datastore/store/store_test.go | 2 ++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/central/blob/datastore/datastore.go b/central/blob/datastore/datastore.go index 6c8f8360d5d1a..116d7ec224510 100644 --- a/central/blob/datastore/datastore.go +++ b/central/blob/datastore/datastore.go @@ -12,6 +12,7 @@ import ( type Datastore interface { Upsert(ctx context.Context, obj *storage.Blob, reader io.Reader) error Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) + Delete(ctx context.Context, name string) error } // NewDatastore creates a new Blob datastore @@ -34,3 +35,8 @@ func (d *datastoreImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io func (d *datastoreImpl) Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) { return d.store.Get(ctx, name, writer) } + +// Delete removes a blob store from database +func (d *datastoreImpl) 
Delete(ctx context.Context, name string) error { + return d.store.Delete(ctx, name) +} diff --git a/central/blob/datastore/store/store.go b/central/blob/datastore/store/store.go index d9594087de7d5..0b98f14410803 100644 --- a/central/blob/datastore/store/store.go +++ b/central/blob/datastore/store/store.go @@ -18,6 +18,7 @@ var log = logging.LoggerForModule() type Store interface { Upsert(ctx context.Context, obj *storage.Blob, reader io.Reader) error Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) + Delete(ctx context.Context, name string) error } type storeImpl struct { @@ -116,7 +117,7 @@ func (s *storeImpl) Get(ctx context.Context, name string, writer io.Writer) (*st ctx = pgPkg.ContextWithTx(ctx, tx) los := tx.LargeObjects() - lo, err := los.Open(ctx, existingBlob.GetOid(), pgx.LargeObjectModeWrite) + lo, err := los.Open(ctx, existingBlob.GetOid(), pgx.LargeObjectModeRead) if err != nil { err := errors.Wrapf(err, "error opening large object with oid %d", existingBlob.GetOid()) return nil, false, wrapRollback(ctx, tx, err) @@ -146,3 +147,31 @@ func (s *storeImpl) Get(ctx context.Context, name string, writer io.Writer) (*st return existingBlob, true, tx.Commit(ctx) } + +// Delete removes a blob from database if it exists +func (s *storeImpl) Delete(ctx context.Context, name string) error { + existingBlob, exists, err := s.store.Get(ctx, name) + if err != nil { + return err + } + if !exists { + return nil + } + + tx, err := s.db.Begin(ctx) + if err != nil { + return err + } + + ctx = pgPkg.ContextWithTx(ctx, tx) + los := tx.LargeObjects() + if err = los.Unlink(ctx, existingBlob.GetOid()); err != nil { + return errors.Wrapf(err, "failed to remove large object with oid %d", existingBlob.GetOid()) + } + if err = s.store.Delete(ctx, name); err != nil { + err = errors.Wrapf(err, "deleting large object %s", name) + return wrapRollback(ctx, tx, err) + } + + return tx.Commit(ctx) +} diff --git 
a/central/blob/datastore/store/store_test.go b/central/blob/datastore/store/store_test.go index 77d6c28fbacd5..9c7b48466aa1f 100644 --- a/central/blob/datastore/store/store_test.go +++ b/central/blob/datastore/store/store_test.go @@ -71,4 +71,6 @@ func (s *BlobsStoreSuite) TestStore() { s.NotZero(blob.GetOid()) s.Equal(insertBlob, blob) s.Equal(randomData, buf.Bytes()) + + s.NoError(s.store.Delete(ctx, insertBlob.GetName())) } From 8d2a24b1e9f5e33ebafc60e9ea8aa002b37fecc6 Mon Sep 17 00:00:00 2001 From: Cong Du Date: Mon, 1 May 2023 11:17:53 -0700 Subject: [PATCH 07/40] regen --- generated/storage/blob.pb.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generated/storage/blob.pb.go b/generated/storage/blob.pb.go index 74882046ff782..e27e987691b2b 100644 --- a/generated/storage/blob.pb.go +++ b/generated/storage/blob.pb.go @@ -24,7 +24,7 @@ var _ = math.Inf // proto package needs to be updated. const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package -// Next Tag: 4 +// Next Tag: 6 type Blob struct { Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty" sql:"pk"` Oid uint32 `protobuf:"varint,2,opt,name=oid,proto3" json:"oid,omitempty"` From 96cfee944c52d39cc0b7ef84e4478c9bf5a5dfe6 Mon Sep 17 00:00:00 2001 From: cdu Date: Mon, 8 May 2023 20:32:22 -0700 Subject: [PATCH 08/40] Address review comments --- central/blob/datastore/store/store.go | 36 +++++----- central/blob/datastore/store/store_test.go | 23 ++++++ generated/storage/blob.pb.go | 84 +++++++++++++++------- proto/storage/blob.proto | 7 +- 4 files changed, 104 insertions(+), 46 deletions(-) diff --git a/central/blob/datastore/store/store.go b/central/blob/datastore/store/store.go index 0b98f14410803..6e94cbb00a667 100644 --- a/central/blob/datastore/store/store.go +++ b/central/blob/datastore/store/store.go @@ -44,15 +44,15 @@ func wrapRollback(ctx context.Context, tx *pgPkg.Tx, err error) error { // Upsert adds a blob to the database func (s 
*storeImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io.Reader) error { - existingBlob, exists, err := s.store.Get(ctx, obj.GetName()) - if err != nil { - return err - } tx, err := s.db.Begin(ctx) if err != nil { return err } ctx = pgPkg.ContextWithTx(ctx, tx) + existingBlob, exists, err := s.store.Get(ctx, obj.GetName()) + if err != nil { + return wrapRollback(ctx, tx, err) + } los := tx.LargeObjects() var lo *pgx.LargeObject @@ -62,7 +62,7 @@ func (s *storeImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io.Rea return wrapRollback(ctx, tx, errors.Wrapf(err, "opening blob with oid %d", existingBlob.GetOid())) } if err := lo.Truncate(0); err != nil { - return errors.Wrapf(err, "truncating blob with oid %d", existingBlob.GetOid()) + return wrapRollback(ctx, tx, errors.Wrapf(err, "truncating blob with oid %d", existingBlob.GetOid())) } } else { oid, err := los.Create(ctx, 0) @@ -105,17 +105,17 @@ func (s *storeImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io.Rea // Get returns a blob from the database func (s *storeImpl) Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) { - existingBlob, exists, err := s.store.Get(ctx, name) - if err != nil || !exists { - return nil, exists, err - } - tx, err := s.db.Begin(ctx) if err != nil { return nil, false, err } ctx = pgPkg.ContextWithTx(ctx, tx) + existingBlob, exists, err := s.store.Get(ctx, name) + if err != nil || !exists { + return nil, exists, wrapRollback(ctx, tx, err) + } + los := tx.LargeObjects() lo, err := los.Open(ctx, existingBlob.GetOid(), pgx.LargeObjectModeRead) if err != nil { @@ -150,23 +150,21 @@ func (s *storeImpl) Get(ctx context.Context, name string, writer io.Writer) (*st // Delete removes a blob from database if it exists func (s *storeImpl) Delete(ctx context.Context, name string) error { - existingBlob, exists, err := s.store.Get(ctx, name) - if err != nil { - return err - } - if !exists { - return nil - } - tx, err := 
s.db.Begin(ctx) if err != nil { return err } ctx = pgPkg.ContextWithTx(ctx, tx) + + existingBlob, exists, err := s.store.Get(ctx, name) + if err != nil || !exists { + return wrapRollback(ctx, tx, err) + } + los := tx.LargeObjects() if err = los.Unlink(ctx, existingBlob.GetOid()); err != nil { - return errors.Wrapf(err, "failed to remove large object with oid %d", existingBlob.GetOid()) + return wrapRollback(ctx, tx, errors.Wrapf(err, "failed to remove large object with oid %d", existingBlob.GetOid())) } if err = s.store.Delete(ctx, name); err != nil { err = errors.Wrapf(err, "deleting large object %s", name) diff --git a/central/blob/datastore/store/store_test.go b/central/blob/datastore/store/store_test.go index 9c7b48466aa1f..5ecfa14602b9c 100644 --- a/central/blob/datastore/store/store_test.go +++ b/central/blob/datastore/store/store_test.go @@ -69,8 +69,31 @@ func (s *BlobsStoreSuite) TestStore() { s.Require().NoError(err) s.Require().True(exists) s.NotZero(blob.GetOid()) + s.verifyLargeObjectCounts(1) s.Equal(insertBlob, blob) s.Equal(randomData, buf.Bytes()) s.NoError(s.store.Delete(ctx, insertBlob.GetName())) + + buf.Truncate(0) + blob, exists, err = s.store.Get(ctx, insertBlob.GetName(), buf) + s.Require().NoError(err) + s.Require().False(exists) + s.Zero(blob.GetOid()) + s.Nil(blob) + s.Zero(buf.Len()) + s.verifyLargeObjectCounts(0) +} + +func (s *BlobsStoreSuite) verifyLargeObjectCounts(expected int) { + ctx := context.Background() + tx, err := s.testDB.DB.Begin(context.Background()) + s.Require().NoError(err) + + defer func() { _ = tx.Rollback(ctx) }() + + var n int + err = tx.QueryRow(ctx, "SELECT COUNT(*) FROM pg_largeobject_metadata;").Scan(&n) + s.NoError(err) + s.Require().Equal(expected, n) } diff --git a/generated/storage/blob.pb.go b/generated/storage/blob.pb.go index e27e987691b2b..7649556477532 100644 --- a/generated/storage/blob.pb.go +++ b/generated/storage/blob.pb.go @@ -24,13 +24,14 @@ var _ = math.Inf // proto package needs to be updated. 
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package -// Next Tag: 6 +// Next Tag: 7 type Blob struct { Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty" sql:"pk"` Oid uint32 `protobuf:"varint,2,opt,name=oid,proto3" json:"oid,omitempty"` Checksum string `protobuf:"bytes,3,opt,name=checksum,proto3" json:"checksum,omitempty"` - LastUpdated *types.Timestamp `protobuf:"bytes,4,opt,name=last_updated,json=lastUpdated,proto3" json:"last_updated,omitempty"` - ModifiedTime *types.Timestamp `protobuf:"bytes,5,opt,name=modified_time,json=modifiedTime,proto3" json:"modified_time,omitempty"` + Length int64 `protobuf:"varint,4,opt,name=length,proto3" json:"length,omitempty"` + LastUpdated *types.Timestamp `protobuf:"bytes,5,opt,name=last_updated,json=lastUpdated,proto3" json:"last_updated,omitempty"` + ModifiedTime *types.Timestamp `protobuf:"bytes,6,opt,name=modified_time,json=modifiedTime,proto3" json:"modified_time,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -90,6 +91,13 @@ func (m *Blob) GetChecksum() string { return "" } +func (m *Blob) GetLength() int64 { + if m != nil { + return m.Length + } + return 0 +} + func (m *Blob) GetLastUpdated() *types.Timestamp { if m != nil { return m.LastUpdated @@ -126,24 +134,25 @@ func init() { func init() { proto.RegisterFile("storage/blob.proto", fileDescriptor_93b63e008eb8666f) } var fileDescriptor_93b63e008eb8666f = []byte{ - // 273 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2a, 0x2e, 0xc9, 0x2f, - 0x4a, 0x4c, 0x4f, 0xd5, 0x4f, 0xca, 0xc9, 0x4f, 0xd2, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, - 0x87, 0x8a, 0x49, 0x89, 0xa4, 0xe7, 0xa7, 0xe7, 0x83, 0xc5, 0xf4, 0x41, 0x2c, 0x88, 0xb4, 0x94, - 0x7c, 0x7a, 0x7e, 0x7e, 0x7a, 0x4e, 0xaa, 0x3e, 0x98, 0x97, 0x54, 0x9a, 0xa6, 0x5f, 0x92, 0x99, - 0x9b, 0x5a, 0x5c, 0x92, 0x98, 0x5b, 0x00, 0x51, 0xa0, 
0x74, 0x8d, 0x91, 0x8b, 0xc5, 0x29, 0x27, - 0x3f, 0x49, 0x48, 0x81, 0x8b, 0x25, 0x2f, 0x31, 0x37, 0x55, 0x82, 0x51, 0x81, 0x51, 0x83, 0xd3, - 0x89, 0xe7, 0xd3, 0x3d, 0x79, 0x8e, 0xe2, 0xc2, 0x1c, 0x2b, 0xa5, 0x82, 0x6c, 0xa5, 0x20, 0xb0, - 0x8c, 0x90, 0x00, 0x17, 0x73, 0x7e, 0x66, 0x8a, 0x04, 0x93, 0x02, 0xa3, 0x06, 0x6f, 0x10, 0x88, - 0x29, 0x24, 0xc5, 0xc5, 0x91, 0x9c, 0x91, 0x9a, 0x9c, 0x5d, 0x5c, 0x9a, 0x2b, 0xc1, 0x0c, 0xd2, - 0x17, 0x04, 0xe7, 0x0b, 0xd9, 0x72, 0xf1, 0xe4, 0x24, 0x16, 0x97, 0xc4, 0x97, 0x16, 0xa4, 0x24, - 0x96, 0xa4, 0xa6, 0x48, 0xb0, 0x28, 0x30, 0x6a, 0x70, 0x1b, 0x49, 0xe9, 0x41, 0x1c, 0xa4, 0x07, - 0x73, 0x90, 0x5e, 0x08, 0xcc, 0x41, 0x41, 0xdc, 0x20, 0xf5, 0xa1, 0x10, 0xe5, 0x42, 0xf6, 0x5c, - 0xbc, 0xb9, 0xf9, 0x29, 0x99, 0x69, 0x99, 0xa9, 0x29, 0xf1, 0x20, 0x37, 0x4b, 0xb0, 0x12, 0xd4, - 0xcf, 0x03, 0xd3, 0x00, 0x12, 0x72, 0x32, 0x39, 0xf1, 0x48, 0x8e, 0xf1, 0xc2, 0x23, 0x39, 0xc6, - 0x07, 0x8f, 0xe4, 0x18, 0x67, 0x3c, 0x96, 0x63, 0xe0, 0x92, 0xcc, 0xcc, 0xd7, 0x2b, 0x2e, 0x49, - 0x4c, 0xce, 0x2e, 0xca, 0xaf, 0x80, 0xe8, 0xd7, 0x83, 0x06, 0x5e, 0x14, 0x2c, 0x14, 0x93, 0xd8, - 0xc0, 0xe2, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0xc0, 0x63, 0xb7, 0xe3, 0x6b, 0x01, 0x00, + // 289 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x90, 0xcf, 0x4a, 0xf4, 0x30, + 0x14, 0xc5, 0xbf, 0x7c, 0x33, 0x8e, 0x63, 0xa6, 0x03, 0x12, 0x44, 0x62, 0x17, 0x9d, 0xd2, 0x55, + 0x57, 0x29, 0xa8, 0x2b, 0x41, 0x84, 0x3e, 0x42, 0xd1, 0x8d, 0x9b, 0x21, 0x6d, 0x33, 0x99, 0xd2, + 0x74, 0x6e, 0x6d, 0x52, 0xf0, 0x51, 0x7c, 0x24, 0x97, 0x3e, 0x81, 0x48, 0xdd, 0xbb, 0xf0, 0x09, + 0x24, 0xfd, 0xe3, 0xd6, 0xdd, 0x3d, 0xe7, 0x9e, 0x1f, 0x9c, 0x7b, 0x31, 0xd1, 0x06, 0x1a, 0x2e, + 0x45, 0x94, 0x2a, 0x48, 0x59, 0xdd, 0x80, 0x01, 0x72, 0x3c, 0x7a, 0xee, 0x99, 0x04, 0x09, 0xbd, + 0x17, 0xd9, 0x69, 0x58, 0xbb, 0x1b, 0x09, 0x20, 0x95, 0x88, 0x7a, 0x95, 0xb6, 0xbb, 0xc8, 0x14, + 0x95, 0xd0, 0x86, 0x57, 0xf5, 0x10, 0x08, 
0xbe, 0x10, 0x9e, 0xc7, 0x0a, 0x52, 0xe2, 0xe3, 0xf9, + 0x81, 0x57, 0x82, 0x22, 0x1f, 0x85, 0x27, 0xb1, 0xf3, 0xfd, 0xbe, 0x59, 0xea, 0x27, 0x75, 0x13, + 0xd4, 0x65, 0x90, 0xf4, 0x1b, 0x72, 0x8a, 0x67, 0x50, 0xe4, 0xf4, 0xbf, 0x8f, 0xc2, 0x75, 0x62, + 0x47, 0xe2, 0xe2, 0x65, 0xb6, 0x17, 0x59, 0xa9, 0xdb, 0x8a, 0xce, 0x2c, 0x97, 0xfc, 0x6a, 0x72, + 0x8e, 0x17, 0x4a, 0x1c, 0xa4, 0xd9, 0xd3, 0xb9, 0x8f, 0xc2, 0x59, 0x32, 0x2a, 0x72, 0x8b, 0x1d, + 0xc5, 0xb5, 0xd9, 0xb6, 0x75, 0xce, 0x8d, 0xc8, 0xe9, 0x91, 0x8f, 0xc2, 0xd5, 0xa5, 0xcb, 0x86, + 0xa2, 0x6c, 0x2a, 0xca, 0xee, 0xa7, 0xa2, 0xc9, 0xca, 0xe6, 0x1f, 0x86, 0x38, 0xb9, 0xc3, 0xeb, + 0x0a, 0xf2, 0x62, 0x57, 0x88, 0x7c, 0x6b, 0x6f, 0xa1, 0x8b, 0x3f, 0x79, 0x67, 0x02, 0xac, 0x15, + 0x5f, 0xbf, 0x76, 0x1e, 0x7a, 0xeb, 0x3c, 0xf4, 0xd1, 0x79, 0xe8, 0xe5, 0xd3, 0xfb, 0x87, 0x2f, + 0x0a, 0x60, 0xda, 0xf0, 0xac, 0x6c, 0xe0, 0x79, 0xe0, 0xd9, 0xf8, 0xd4, 0xc7, 0xe9, 0xbb, 0xe9, + 0xa2, 0xf7, 0xaf, 0x7e, 0x02, 0x00, 0x00, 0xff, 0xff, 0x90, 0x14, 0xe3, 0x5a, 0x83, 0x01, 0x00, 0x00, } @@ -181,7 +190,7 @@ func (m *Blob) MarshalToSizedBuffer(dAtA []byte) (int, error) { i = encodeVarintBlob(dAtA, i, uint64(size)) } i-- - dAtA[i] = 0x2a + dAtA[i] = 0x32 } if m.LastUpdated != nil { { @@ -193,7 +202,12 @@ func (m *Blob) MarshalToSizedBuffer(dAtA []byte) (int, error) { i = encodeVarintBlob(dAtA, i, uint64(size)) } i-- - dAtA[i] = 0x22 + dAtA[i] = 0x2a + } + if m.Length != 0 { + i = encodeVarintBlob(dAtA, i, uint64(m.Length)) + i-- + dAtA[i] = 0x20 } if len(m.Checksum) > 0 { i -= len(m.Checksum) @@ -245,6 +259,9 @@ func (m *Blob) Size() (n int) { if l > 0 { n += 1 + l + sovBlob(uint64(l)) } + if m.Length != 0 { + n += 1 + sovBlob(uint64(m.Length)) + } if m.LastUpdated != nil { l = m.LastUpdated.Size() n += 1 + l + sovBlob(uint64(l)) @@ -378,6 +395,25 @@ func (m *Blob) Unmarshal(dAtA []byte) error { m.Checksum = string(dAtA[iNdEx:postIndex]) iNdEx = postIndex case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = 
%d for field Length", wireType) + } + m.Length = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowBlob + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Length |= int64(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 5: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field LastUpdated", wireType) } @@ -413,7 +449,7 @@ func (m *Blob) Unmarshal(dAtA []byte) error { return err } iNdEx = postIndex - case 5: + case 6: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field ModifiedTime", wireType) } diff --git a/proto/storage/blob.proto b/proto/storage/blob.proto index 9d97a5493d64b..c8681e62f4cc5 100644 --- a/proto/storage/blob.proto +++ b/proto/storage/blob.proto @@ -8,11 +8,12 @@ import "google/protobuf/timestamp.proto"; package storage; -// Next Tag: 6 +// Next Tag: 7 message Blob { string name = 1 [(gogoproto.moretags) = 'sql:"pk"']; uint32 oid = 2; string checksum = 3; - google.protobuf.Timestamp last_updated = 4; - google.protobuf.Timestamp modified_time = 5; + int64 length = 4; + google.protobuf.Timestamp last_updated = 5; + google.protobuf.Timestamp modified_time = 6; } From 13c3efd1ef23d4b5dc062a84e2ee377bcc67dd90 Mon Sep 17 00:00:00 2001 From: cdu Date: Tue, 9 May 2023 18:16:21 -0700 Subject: [PATCH 09/40] no blob --- central/blob/datastore/datastore.go | 6 +- central/blob/datastore/store/store.go | 69 ++++++++++++++-------- central/blob/datastore/store/store_test.go | 25 ++++++++ 3 files changed, 73 insertions(+), 27 deletions(-) diff --git a/central/blob/datastore/datastore.go b/central/blob/datastore/datastore.go index cb53eb61056ec..116d7ec224510 100644 --- a/central/blob/datastore/datastore.go +++ b/central/blob/datastore/datastore.go @@ -11,7 +11,7 @@ import ( // Datastore provides access to the blob store type Datastore interface { Upsert(ctx context.Context, obj *storage.Blob, reader io.Reader) error - Get(ctx context.Context, name 
string) (*storage.Blob, io.ReadCloser, bool, error) + Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) Delete(ctx context.Context, name string) error } @@ -32,8 +32,8 @@ func (d *datastoreImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io } // Get retrieves a blob from the database -func (d *datastoreImpl) Get(ctx context.Context, name string) (*storage.Blob, io.ReadCloser, bool, error) { - return d.store.Get(ctx, name) +func (d *datastoreImpl) Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) { + return d.store.Get(ctx, name, writer) } // Delete removes a blob store from database diff --git a/central/blob/datastore/store/store.go b/central/blob/datastore/store/store.go index 0f63dd6bf8550..6e94cbb00a667 100644 --- a/central/blob/datastore/store/store.go +++ b/central/blob/datastore/store/store.go @@ -17,7 +17,7 @@ var log = logging.LoggerForModule() // Store is the interface to interact with the storage for storage.Blob type Store interface { Upsert(ctx context.Context, obj *storage.Blob, reader io.Reader) error - Get(ctx context.Context, name string) (*storage.Blob, io.ReadCloser, bool, error) + Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) Delete(ctx context.Context, name string) error } @@ -44,15 +44,15 @@ func wrapRollback(ctx context.Context, tx *pgPkg.Tx, err error) error { // Upsert adds a blob to the database func (s *storeImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io.Reader) error { - existingBlob, exists, err := s.store.Get(ctx, obj.GetName()) - if err != nil { - return err - } tx, err := s.db.Begin(ctx) if err != nil { return err } ctx = pgPkg.ContextWithTx(ctx, tx) + existingBlob, exists, err := s.store.Get(ctx, obj.GetName()) + if err != nil { + return wrapRollback(ctx, tx, err) + } los := tx.LargeObjects() var lo *pgx.LargeObject @@ -62,7 +62,7 @@ func (s *storeImpl) Upsert(ctx context.Context, obj *storage.Blob, 
reader io.Rea return wrapRollback(ctx, tx, errors.Wrapf(err, "opening blob with oid %d", existingBlob.GetOid())) } if err := lo.Truncate(0); err != nil { - return errors.Wrapf(err, "truncating blob with oid %d", existingBlob.GetOid()) + return wrapRollback(ctx, tx, errors.Wrapf(err, "truncating blob with oid %d", existingBlob.GetOid())) } } else { oid, err := los.Create(ctx, 0) @@ -104,46 +104,67 @@ func (s *storeImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io.Rea } // Get returns a blob from the database -func (s *storeImpl) Get(ctx context.Context, name string) (*storage.Blob, io.ReadCloser, bool, error) { - existingBlob, exists, err := s.store.Get(ctx, name) - if err != nil || !exists { - return nil, nil, exists, err - } - +func (s *storeImpl) Get(ctx context.Context, name string, writer io.Writer) (*storage.Blob, bool, error) { tx, err := s.db.Begin(ctx) if err != nil { - return nil, nil, false, err + return nil, false, err } ctx = pgPkg.ContextWithTx(ctx, tx) + existingBlob, exists, err := s.store.Get(ctx, name) + if err != nil || !exists { + return nil, exists, wrapRollback(ctx, tx, err) + } + los := tx.LargeObjects() lo, err := los.Open(ctx, existingBlob.GetOid(), pgx.LargeObjectModeRead) if err != nil { err := errors.Wrapf(err, "error opening large object with oid %d", existingBlob.GetOid()) - return nil, nil, false, wrapRollback(ctx, tx, err) + return nil, false, wrapRollback(ctx, tx, err) } - return existingBlob, lo, true, tx.Commit(ctx) -} + buf := make([]byte, 1024*1024) + for { + nRead, err := lo.Read(buf) -// Delete removes a blob store from database -func (s *storeImpl) Delete(ctx context.Context, name string) error { - existingBlob, exists, err := s.store.Get(ctx, name) - if err != nil { - return err + // nRead can be non-zero when err == io.EOF + if nRead != 0 { + if _, err := writer.Write(buf[:nRead]); err != nil { + err := errors.Wrap(err, "error writing to output") + return nil, false, wrapRollback(ctx, tx, err) + } + } + if err 
!= nil { + if err == io.EOF { + break + } + } } - if !exists { - return nil + if err := lo.Close(); err != nil { + err = errors.Wrap(err, "closing large object for blob") + return nil, false, wrapRollback(ctx, tx, err) } + + return existingBlob, true, tx.Commit(ctx) +} + +// Delete removes a blob from database if it exists +func (s *storeImpl) Delete(ctx context.Context, name string) error { tx, err := s.db.Begin(ctx) if err != nil { return err } ctx = pgPkg.ContextWithTx(ctx, tx) + + existingBlob, exists, err := s.store.Get(ctx, name) + if err != nil || !exists { + return wrapRollback(ctx, tx, err) + } + los := tx.LargeObjects() if err = los.Unlink(ctx, existingBlob.GetOid()); err != nil { - return errors.Wrapf(err, "failed to remove large object with oid %d", existingBlob.GetOid()) + return wrapRollback(ctx, tx, errors.Wrapf(err, "failed to remove large object with oid %d", existingBlob.GetOid())) } if err = s.store.Delete(ctx, name); err != nil { err = errors.Wrapf(err, "deleting large object %s", name) diff --git a/central/blob/datastore/store/store_test.go b/central/blob/datastore/store/store_test.go index 77d6c28fbacd5..5ecfa14602b9c 100644 --- a/central/blob/datastore/store/store_test.go +++ b/central/blob/datastore/store/store_test.go @@ -69,6 +69,31 @@ func (s *BlobsStoreSuite) TestStore() { s.Require().NoError(err) s.Require().True(exists) s.NotZero(blob.GetOid()) + s.verifyLargeObjectCounts(1) s.Equal(insertBlob, blob) s.Equal(randomData, buf.Bytes()) + + s.NoError(s.store.Delete(ctx, insertBlob.GetName())) + + buf.Truncate(0) + blob, exists, err = s.store.Get(ctx, insertBlob.GetName(), buf) + s.Require().NoError(err) + s.Require().False(exists) + s.Zero(blob.GetOid()) + s.Nil(blob) + s.Zero(buf.Len()) + s.verifyLargeObjectCounts(0) +} + +func (s *BlobsStoreSuite) verifyLargeObjectCounts(expected int) { + ctx := context.Background() + tx, err := s.testDB.DB.Begin(context.Background()) + s.Require().NoError(err) + + defer func() { _ = tx.Rollback(ctx) }() 
+ + var n int + err = tx.QueryRow(ctx, "SELECT COUNT(*) FROM pg_largeobject_metadata;").Scan(&n) + s.NoError(err) + s.Require().Equal(expected, n) } From f9d3511b868551ff4dbe2dc235ef4ebf4d4dbe83 Mon Sep 17 00:00:00 2001 From: cdu Date: Tue, 9 May 2023 22:36:21 -0700 Subject: [PATCH 10/40] stage --- .../handler/definition_file.go | 11 ++++ central/scannerdefinitions/handler/handler.go | 54 +++++++++++++++---- pkg/fileutils/temp_file.go | 34 ++++++++++++ 3 files changed, 88 insertions(+), 11 deletions(-) create mode 100644 central/scannerdefinitions/handler/definition_file.go create mode 100644 pkg/fileutils/temp_file.go diff --git a/central/scannerdefinitions/handler/definition_file.go b/central/scannerdefinitions/handler/definition_file.go new file mode 100644 index 0000000000000..ac4e0ce84bf09 --- /dev/null +++ b/central/scannerdefinitions/handler/definition_file.go @@ -0,0 +1,11 @@ +package handler + +import "io" + +// definitionFileReader unifies online and offline reader for vuln definitions +type definitionFileReader interface { + io.Reader + io.Closer + io.Seeker + Name() string +} diff --git a/central/scannerdefinitions/handler/handler.go b/central/scannerdefinitions/handler/handler.go index c2359d4c13dbd..c464fdc7a0610 100644 --- a/central/scannerdefinitions/handler/handler.go +++ b/central/scannerdefinitions/handler/handler.go @@ -212,9 +212,10 @@ func (h *httpHandler) handleScannerDefsFile(zipF *zip.File) error { // POST requests only update the offline feed. 
b := &storage.Blob{ - Name: scannerDefinationBlobName, + Name: scannerDefinitionBlobName, LastUpdated: timestamp.TimestampNow(), ModifiedTime: timestamp.TimestampNow(), + Length: zipF.FileInfo().Size(), } if err := h.blobStore.Upsert(context.Background(), b, r); err != nil { @@ -224,7 +225,7 @@ func (h *httpHandler) handleScannerDefsFile(zipF *zip.File) error { return nil } -func (h *httpHandler) handleZipContentsFromVulnDump(ctx context.Context, zipPath string) error { +func (h *httpHandler) handleZipContentsFromVulnDump(zipPath string) error { zipR, err := zip.OpenReader(zipPath) if err != nil { return errors.Wrap(err, "couldn't open file as zip") @@ -268,7 +269,7 @@ func (h *httpHandler) post(w http.ResponseWriter, r *http.Request) { return } - if err := h.handleZipContentsFromVulnDump(r.Context(), tempFile); err != nil { + if err := h.handleZipContentsFromVulnDump(tempFile); err != nil { httputil.WriteGRPCStyleError(w, codes.Internal, err) return } @@ -313,16 +314,10 @@ func (h *httpHandler) cleanupUpdaters(cleanupAge time.Duration) { // online, otherwise fallback to the manually uploaded definitions. The file // object can be `nil` if the definitions file does not exist, rather than // returning an error. -func (h *httpHandler) openMostRecentDefinitions(uuid string) (rc io.ReadCloser, modTime time.Time, err error) { +func (h *httpHandler) openMostRecentDefinitions(uuid string) (rc definitionFileReader, modTime time.Time, err error) { // If in offline mode or uuid is not provided, default to the offline file. 
if !h.online || uuid == "" { - var blob *storage.Blob - blob, rc, _, err = h.blobStore.Get(context.Background(), scannerDefinitionBlobName) - if err != nil { - err = errors.Wrapf(err, "failed to open offline scanner definition bundle") - return - } - modTime = *pgutils.NilOrTime(blob.GetLastUpdated()) + rc, modTime, err = h.open() return } @@ -409,3 +404,40 @@ func openFromArchive(archiveFile string, fileName string) (*os.File, error) { } return tmpFile, nil } + +func (h *httpHandler) open() (rnc definitionFileReader, modTime time.Time, err error) { + var tempDir string + tempDir, err = os.MkdirTemp("", "scanner-definitions-") + if err != nil { + return + } + defer func() { + if err != nil { + _ = os.RemoveAll(tempDir) + } + }() + + tempFile := filepath.Join(tempDir, "tempfile.zip") + var writer *os.File + writer, err = os.Create(tempFile) + if err != nil { + return + } + + var blob *storage.Blob + blob, _, err = h.blobStore.Get(context.Background(), scannerDefinitionBlobName, writer) + if err != nil { + err = errors.Wrapf(err, "failed to open offline scanner definition bundle") + return + } + err = writer.Close() + if err != nil { + return + } + rnc, err = fileutils.CreateTempFile(tempDir, scannerDefsSubZipName) + if err != nil { + return + } + modTime = *pgutils.NilOrTime(blob.GetLastUpdated()) + return +} diff --git a/pkg/fileutils/temp_file.go b/pkg/fileutils/temp_file.go new file mode 100644 index 0000000000000..4f1d16c1976bd --- /dev/null +++ b/pkg/fileutils/temp_file.go @@ -0,0 +1,34 @@ +package fileutils + +import ( + "os" + "path" +) + +// CreateTempFile creates a temp dir with a file. The file and its temp dir will be removed on closure. +func CreateTempFile(p string, name string) (*tempFile, error) { + file, err := os.Open(path.Join(p, name)) + if err != nil { + return nil, err + } + return &tempFile{File: file, path: p, name: name}, nil +} + +type tempFile struct { + *os.File + path string + name string +} + +// Close temp file and remove its temp dir. 
+func (f *tempFile) Close() error { + err := f.File.Close() + if removeErr := os.RemoveAll(f.path); err != nil { + log.Errorf("failed to remove %q: %v", f.path, removeErr) + } + return err +} + +func (f *tempFile) Name() string { + return f.name +} From e1129f4a5e992aac21360b7043028d61a8aa89e0 Mon Sep 17 00:00:00 2001 From: cdu Date: Wed, 10 May 2023 00:10:24 -0700 Subject: [PATCH 11/40] stage --- central/blob/blobfile/blob_file.go | 60 +++++++++++++++++++ central/blob/blobfile/utils.go | 47 +++++++++++++++ central/scannerdefinitions/handler/handler.go | 52 ++++------------ pkg/fileutils/temp_file.go | 34 ----------- 4 files changed, 119 insertions(+), 74 deletions(-) create mode 100644 central/blob/blobfile/blob_file.go create mode 100644 central/blob/blobfile/utils.go delete mode 100644 pkg/fileutils/temp_file.go diff --git a/central/blob/blobfile/blob_file.go b/central/blob/blobfile/blob_file.go new file mode 100644 index 0000000000000..94deaeffd7c19 --- /dev/null +++ b/central/blob/blobfile/blob_file.go @@ -0,0 +1,60 @@ +package blobfile + +import ( + "io" + "os" + "path" + + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/logging" +) + +var ( + log = logging.LoggerForModule() +) + +type ReadOnlyBlobFile interface { + io.Reader + io.Closer + io.Seeker + Name() string + + GetBlob() *storage.Blob +} + +// CreateBlobFile creates a temp dir with a file. The file and its temp dir will be removed on closure. +func CreateBlobFile(p string, blob *storage.Blob) (*blobFile, error) { + file, err := os.Open(p) + if err != nil { + return nil, err + } + return &blobFile{File: file, blob: blob, path: p}, nil +} + +type blobFile struct { + *os.File + blob *storage.Blob + path string +} + +// Close temp file and remove its temp dir. 
+func (f *blobFile) Close() error { + err := f.File.Close() + if removeErr := os.RemoveAll(f.path); err != nil { + log.Errorf("failed to remove %q: %v", f.path, removeErr) + } + return err +} + +func (f *blobFile) Name() string { + name := path.Base(f.blob.GetName()) + if name == "." || name == "/" { + return "noname" + } + return name +} + +func (f *blobFile) GetBlob() *storage.Blob { + // return f.blob.Clone() + return f.blob +} diff --git a/central/blob/blobfile/utils.go b/central/blob/blobfile/utils.go new file mode 100644 index 0000000000000..57b94cccdf6f5 --- /dev/null +++ b/central/blob/blobfile/utils.go @@ -0,0 +1,47 @@ +package blobfile + +import ( + "context" + "os" + "path/filepath" + + "github.com/pkg/errors" + "github.com/stackrox/rox/central/blob/datastore" + "github.com/stackrox/rox/generated/storage" +) + +func BlobSnapshot(blobStore datastore.Datastore, name string) (rnc ReadOnlyBlobFile, err error) { + var tempDir string + tempDir, err = os.MkdirTemp("", "blob-file-") + if err != nil { + return + } + defer func() { + if err != nil { + _ = os.RemoveAll(tempDir) + } + }() + + tempFile := filepath.Join(tempDir, "blob.data") + var writer *os.File + writer, err = os.Create(tempFile) + if err != nil { + return + } + + var blob *storage.Blob + blob, _, err = blobStore.Get(context.Background(), name, writer) + if err != nil { + err = errors.Wrapf(err, "failed to open blob with name %q", name) + return + } + err = writer.Close() + if err != nil { + return + } + rnc, err = CreateBlobFile(tempFile, blob) + if err != nil { + return + } + return +} diff --git a/central/scannerdefinitions/handler/handler.go b/central/scannerdefinitions/handler/handler.go index c464fdc7a0610..8a57cda9bc1c9 100644 --- a/central/scannerdefinitions/handler/handler.go +++ b/central/scannerdefinitions/handler/handler.go @@ -14,6 +14,7 @@ import ( timestamp "github.com/gogo/protobuf/types" "github.com/pkg/errors" + "github.com/stackrox/rox/central/blob/blobfile" blob 
"github.com/stackrox/rox/central/blob/datastore" "github.com/stackrox/rox/central/cve/fetcher" "github.com/stackrox/rox/central/scannerdefinitions/file" @@ -73,7 +74,6 @@ type httpHandler struct { lock sync.Mutex updaters map[string]*requestedUpdater onlineVulnDir string - offlineFile *file.File blobStore blob.Datastore } @@ -309,6 +309,15 @@ func (h *httpHandler) cleanupUpdaters(cleanupAge time.Duration) { } } +func (h *httpHandler) openOfflineBlob() (definitionFileReader, time.Time, error) { + snapshot, err := blobfile.BlobSnapshot(h.blobStore, offlineScannerDefsName) + if err != nil { + return nil, time.Time{}, err + } + modTime := *pgutils.NilOrTime(snapshot.GetBlob().ModifiedTime) + return snapshot, modTime, nil +} + // openMostRecentDefinitions opens the latest Scanner Definitions based on // modification time. It's either the one selected by `uuid` if present and // online, otherwise fallback to the manually uploaded definitions. The file @@ -317,7 +326,7 @@ func (h *httpHandler) cleanupUpdaters(cleanupAge time.Duration) { func (h *httpHandler) openMostRecentDefinitions(uuid string) (rc definitionFileReader, modTime time.Time, err error) { // If in offline mode or uuid is not provided, default to the offline file. 
if !h.online || uuid == "" { - rc, modTime, err = h.open() + rc, modTime, err = h.openOfflineBlob() return } @@ -332,7 +341,7 @@ func (h *httpHandler) openMostRecentDefinitions(uuid string) (rc definitionFileR if err != nil { return } - offlineFile, offlineTime, err := h.offlineFile.Open() + offlineFile, offlineTime, err := h.openOfflineBlob() if err != nil { utils.IgnoreError(onlineFile.Close) return @@ -404,40 +413,3 @@ func openFromArchive(archiveFile string, fileName string) (*os.File, error) { } return tmpFile, nil } - -func (h *httpHandler) open() (rnc definitionFileReader, modTime time.Time, err error) { - var tempDir string - tempDir, err = os.MkdirTemp("", "scanner-definitions-") - if err != nil { - return - } - defer func() { - if err != nil { - _ = os.RemoveAll(tempDir) - } - }() - - tempFile := filepath.Join(tempDir, "tempfile.zip") - var writer *os.File - writer, err = os.Create(tempFile) - if err != nil { - return - } - - var blob *storage.Blob - blob, _, err = h.blobStore.Get(context.Background(), scannerDefinitionBlobName, writer) - if err != nil { - err = errors.Wrapf(err, "failed to open offline scanner definition bundle") - return - } - err = writer.Close() - if err != nil { - return - } - rnc, err = fileutils.CreateTempFile(tempDir, scannerDefsSubZipName) - if err != nil { - return - } - modTime = *pgutils.NilOrTime(blob.GetLastUpdated()) - return -} diff --git a/pkg/fileutils/temp_file.go b/pkg/fileutils/temp_file.go deleted file mode 100644 index 4f1d16c1976bd..0000000000000 --- a/pkg/fileutils/temp_file.go +++ /dev/null @@ -1,34 +0,0 @@ -package fileutils - -import ( - "os" - "path" -) - -// CreateTempFile creates a temp dir with a file. The file and its temp dir will be removed on closure. 
-func CreateTempFile(p string, name string) (*tempFile, error) { - file, err := os.Open(path.Join(p, name)) - if err != nil { - return nil, err - } - return &tempFile{File: file, path: p, name: name}, nil -} - -type tempFile struct { - *os.File - path string - name string -} - -// Close temp file and remove its temp dir. -func (f *tempFile) Close() error { - err := f.File.Close() - if removeErr := os.RemoveAll(f.path); err != nil { - log.Errorf("failed to remove %q: %v", f.path, removeErr) - } - return err -} - -func (f *tempFile) Name() string { - return f.name -} From 44e45b54dac81ba86ab34f4b9a30a50f45bcaa44 Mon Sep 17 00:00:00 2001 From: cdu Date: Wed, 10 May 2023 09:04:58 -0700 Subject: [PATCH 12/40] stage --- .../blob/{blobfile => snapshot}/blob_file.go | 22 +++++++++++-------- central/blob/{blobfile => snapshot}/utils.go | 7 +++--- central/scannerdefinitions/handler/handler.go | 4 ++-- 3 files changed, 19 insertions(+), 14 deletions(-) rename central/blob/{blobfile => snapshot}/blob_file.go (53%) rename central/blob/{blobfile => snapshot}/utils.go (77%) diff --git a/central/blob/blobfile/blob_file.go b/central/blob/snapshot/blob_file.go similarity index 53% rename from central/blob/blobfile/blob_file.go rename to central/blob/snapshot/blob_file.go index 94deaeffd7c19..da0eacfc695ca 100644 --- a/central/blob/blobfile/blob_file.go +++ b/central/blob/snapshot/blob_file.go @@ -1,4 +1,4 @@ -package blobfile +package snapshot import ( "io" @@ -13,7 +13,9 @@ var ( log = logging.LoggerForModule() ) -type ReadOnlyBlobFile interface { +// Snapshot contains a Blob with read-only blob data backed by a temp file. +// The temp file will be removed on close. +type Snapshot interface { io.Reader io.Closer io.Seeker @@ -22,23 +24,23 @@ type ReadOnlyBlobFile interface { GetBlob() *storage.Blob } -// CreateBlobFile creates a temp dir with a file. The file and its temp dir will be removed on closure. 
-func CreateBlobFile(p string, blob *storage.Blob) (*blobFile, error) { +// NewBlobSnapshot creates a temp dir with a file. The file and its temp dir will be removed on closure. +func NewBlobSnapshot(p string, blob *storage.Blob) (Snapshot, error) { file, err := os.Open(p) if err != nil { return nil, err } - return &blobFile{File: file, blob: blob, path: p}, nil + return &snapshot{File: file, blob: blob, path: p}, nil } -type blobFile struct { +type snapshot struct { *os.File blob *storage.Blob path string } // Close temp file and remove its temp dir. -func (f *blobFile) Close() error { +func (f *snapshot) Close() error { err := f.File.Close() if removeErr := os.RemoveAll(f.path); err != nil { log.Errorf("failed to remove %q: %v", f.path, removeErr) @@ -46,15 +48,17 @@ func (f *blobFile) Close() error { return err } -func (f *blobFile) Name() string { +func (f *snapshot) Name() string { name := path.Base(f.blob.GetName()) + // It is definitely not a good practice to have use empty string or all slashes in + // Blob name, but here is the workaround in case that happens. if name == "." || name == "/" { return "noname" } return name } -func (f *blobFile) GetBlob() *storage.Blob { +func (f *snapshot) GetBlob() *storage.Blob { // return f.blob.Clone() return f.blob } diff --git a/central/blob/blobfile/utils.go b/central/blob/snapshot/utils.go similarity index 77% rename from central/blob/blobfile/utils.go rename to central/blob/snapshot/utils.go index 57b94cccdf6f5..0a8992b056c7c 100644 --- a/central/blob/blobfile/utils.go +++ b/central/blob/snapshot/utils.go @@ -1,4 +1,4 @@ -package blobfile +package snapshot import ( "context" @@ -10,7 +10,8 @@ import ( "github.com/stackrox/rox/generated/storage" ) -func BlobSnapshot(blobStore datastore.Datastore, name string) (rnc ReadOnlyBlobFile, err error) { +// TakeBlobSnapshot create a Snapshot for the named blob. 
+func TakeBlobSnapshot(blobStore datastore.Datastore, name string) (rnc Snapshot, err error) { var tempDir string tempDir, err = os.MkdirTemp("", "blob-file-") if err != nil { @@ -39,7 +40,7 @@ func BlobSnapshot(blobStore datastore.Datastore, name string) (rnc ReadOnlyBlobF if err != nil { return } - rnc, err = CreateBlobFile(tempFile, blob) + rnc, err = NewBlobSnapshot(tempFile, blob) if err != nil { return } diff --git a/central/scannerdefinitions/handler/handler.go b/central/scannerdefinitions/handler/handler.go index 8a57cda9bc1c9..d1bd4af5a0462 100644 --- a/central/scannerdefinitions/handler/handler.go +++ b/central/scannerdefinitions/handler/handler.go @@ -14,8 +14,8 @@ import ( timestamp "github.com/gogo/protobuf/types" "github.com/pkg/errors" - "github.com/stackrox/rox/central/blob/blobfile" blob "github.com/stackrox/rox/central/blob/datastore" + "github.com/stackrox/rox/central/blob/snapshot" "github.com/stackrox/rox/central/cve/fetcher" "github.com/stackrox/rox/central/scannerdefinitions/file" "github.com/stackrox/rox/generated/storage" @@ -310,7 +310,7 @@ func (h *httpHandler) cleanupUpdaters(cleanupAge time.Duration) { } func (h *httpHandler) openOfflineBlob() (definitionFileReader, time.Time, error) { - snapshot, err := blobfile.BlobSnapshot(h.blobStore, offlineScannerDefsName) + snapshot, err := snapshot.TakeBlobSnapshot(h.blobStore, offlineScannerDefsName) if err != nil { return nil, time.Time{}, err } From c892183eba68ec60266497218e5591fca26c3e54 Mon Sep 17 00:00:00 2001 From: cdu Date: Wed, 10 May 2023 09:48:34 -0700 Subject: [PATCH 13/40] cherry-pick --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index bb7673ca20ce9..d939309681867 100644 --- a/Makefile +++ b/Makefile @@ -68,7 +68,7 @@ ifeq ($(UNAME_S),Darwin) ifeq ($(UNAME_M),arm64) # TODO(ROX-12064) build these images in the CI pipeline # Currently built on a GCP ARM instance off the rox-ci-image branch "cgorman-custom-arm" - BUILD_IMAGE = 
quay.io/rhacs-eng/sandbox:apollo-ci-stackrox-build-0.3.56-arm64 + BUILD_IMAGE = quay.io/rhacs-eng/sandbox:apollo-ci-stackrox-build-0.3.58-arm64 endif endif From 007feb13ca1c8aa26522cc5af4c68d8a352d7658 Mon Sep 17 00:00:00 2001 From: cdu Date: Wed, 10 May 2023 11:45:39 -0700 Subject: [PATCH 14/40] Use blobstore for scanner definitions --- central/blob/datastore/store/store.go | 2 + central/blob/snapshot/blob_file.go | 64 ---------- central/blob/snapshot/snapshot.go | 94 ++++++++++++++ central/blob/snapshot/snapshot_test.go | 79 ++++++++++++ central/blob/snapshot/utils.go | 48 -------- .../handler/definition_file.go | 11 -- central/scannerdefinitions/handler/handler.go | 58 +++++---- .../handler/handler_test.go | 115 ++++++++++++------ central/scannerdefinitions/handler/options.go | 6 - .../scannerdefinitions/handler/singleton.go | 3 +- .../handler/vul_def_reader.go | 11 ++ 11 files changed, 302 insertions(+), 189 deletions(-) delete mode 100644 central/blob/snapshot/blob_file.go create mode 100644 central/blob/snapshot/snapshot.go create mode 100644 central/blob/snapshot/snapshot_test.go delete mode 100644 central/blob/snapshot/utils.go delete mode 100644 central/scannerdefinitions/handler/definition_file.go create mode 100644 central/scannerdefinitions/handler/vul_def_reader.go diff --git a/central/blob/datastore/store/store.go b/central/blob/datastore/store/store.go index 6e94cbb00a667..d8fed9f6ff0b9 100644 --- a/central/blob/datastore/store/store.go +++ b/central/blob/datastore/store/store.go @@ -64,6 +64,7 @@ func (s *storeImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io.Rea if err := lo.Truncate(0); err != nil { return wrapRollback(ctx, tx, errors.Wrapf(err, "truncating blob with oid %d", existingBlob.GetOid())) } + obj.Oid = existingBlob.GetOid() } else { oid, err := los.Create(ctx, 0) if err != nil { @@ -138,6 +139,7 @@ func (s *storeImpl) Get(ctx context.Context, name string, writer io.Writer) (*st if err == io.EOF { break } + return nil, 
false, wrapRollback(ctx, tx, errors.Wrap(err, "reading blob")) } } if err := lo.Close(); err != nil { diff --git a/central/blob/snapshot/blob_file.go b/central/blob/snapshot/blob_file.go deleted file mode 100644 index da0eacfc695ca..0000000000000 --- a/central/blob/snapshot/blob_file.go +++ /dev/null @@ -1,64 +0,0 @@ -package snapshot - -import ( - "io" - "os" - "path" - - "github.com/stackrox/rox/generated/storage" - "github.com/stackrox/rox/pkg/logging" -) - -var ( - log = logging.LoggerForModule() -) - -// Snapshot contains a Blob with read-only blob data backed by a temp file. -// The temp file will be removed on close. -type Snapshot interface { - io.Reader - io.Closer - io.Seeker - Name() string - - GetBlob() *storage.Blob -} - -// NewBlobSnapshot creates a temp dir with a file. The file and its temp dir will be removed on closure. -func NewBlobSnapshot(p string, blob *storage.Blob) (Snapshot, error) { - file, err := os.Open(p) - if err != nil { - return nil, err - } - return &snapshot{File: file, blob: blob, path: p}, nil -} - -type snapshot struct { - *os.File - blob *storage.Blob - path string -} - -// Close temp file and remove its temp dir. -func (f *snapshot) Close() error { - err := f.File.Close() - if removeErr := os.RemoveAll(f.path); err != nil { - log.Errorf("failed to remove %q: %v", f.path, removeErr) - } - return err -} - -func (f *snapshot) Name() string { - name := path.Base(f.blob.GetName()) - // It is definitely not a good practice to have use empty string or all slashes in - // Blob name, but here is the workaround in case that happens. - if name == "." 
|| name == "/" { - return "noname" - } - return name -} - -func (f *snapshot) GetBlob() *storage.Blob { - // return f.blob.Clone() - return f.blob -} diff --git a/central/blob/snapshot/snapshot.go b/central/blob/snapshot/snapshot.go new file mode 100644 index 0000000000000..bf52813ba361a --- /dev/null +++ b/central/blob/snapshot/snapshot.go @@ -0,0 +1,94 @@ +package snapshot + +import ( + "context" + "os" + "path" + "path/filepath" + + "github.com/pkg/errors" + "github.com/stackrox/rox/central/blob/datastore" + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/logging" +) + +var ( + log = logging.LoggerForModule() + ErrBlobNotExist = errors.New("cannot find blob") +) + +// NewBlobSnapshot creates a snapshot of the blob backed up by a temp dir with the blob data stored. +// The temp dir is removed on snapshot closure. +func NewBlobSnapshot(p string, blob *storage.Blob) (*Snapshot, error) { + file, err := os.Open(p) + if err != nil { + return nil, err + } + return &Snapshot{File: file, blob: blob, path: p}, nil +} + +// Snapshot contains a Blob with its data backed by a temp file. +// The temp file will be removed on close. +type Snapshot struct { + *os.File + blob *storage.Blob + path string +} + +// Close temp file and remove its temp dir. 
+func (s *Snapshot) Close() error { + if s == nil || s.File == nil { + return nil + } + err := s.File.Close() + if removeErr := os.RemoveAll(s.path); err != nil { + log.Errorf("failed to remove %q: %v", s.path, removeErr) + } + return err +} + +func (s *Snapshot) GetBlob() *storage.Blob { + return s.blob +} + +// TakeBlobSnapshot create a Snapshot for the named blob if it exists +func TakeBlobSnapshot(ctx context.Context, blobStore datastore.Datastore, name string) (rnc *Snapshot, err error) { + var tempDir string + tempDir, err = os.MkdirTemp("", "blob-file-") + if err != nil { + return + } + defer func() { + if err != nil { + _ = os.RemoveAll(tempDir) + } + }() + + baseName := path.Base(name) + // It is definitely not a good practice to have use empty string or all slashes in + // Blob name, but here is the workaround in case that happens. + if baseName == "." || baseName == "/" { + baseName = "noname" + } + tempFile := filepath.Join(tempDir, baseName) + + var writer *os.File + if writer, err = os.Create(tempFile); err != nil { + return + } + + var blob *storage.Blob + var exists bool + if blob, exists, err = blobStore.Get(ctx, name, writer); err != nil { + err = errors.Wrapf(err, "failed to open blob with name %q", name) + return + } + if err = writer.Close(); err != nil { + return + } + if !exists { + err = ErrBlobNotExist + return + } + return NewBlobSnapshot(tempFile, blob) +} diff --git a/central/blob/snapshot/snapshot_test.go b/central/blob/snapshot/snapshot_test.go new file mode 100644 index 0000000000000..eacc12cebb9cf --- /dev/null +++ b/central/blob/snapshot/snapshot_test.go @@ -0,0 +1,79 @@ +//go:build sql_integration + +package snapshot + +import ( + "bytes" + "context" + "io" + "math/rand" + "testing" + + timestamp "github.com/gogo/protobuf/types" + "github.com/stackrox/rox/central/blob/datastore" + "github.com/stackrox/rox/central/blob/datastore/store" + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/postgres/pgtest" + 
"github.com/stackrox/rox/pkg/sac" + "github.com/stretchr/testify/suite" +) + +type snapshotTestSuite struct { + suite.Suite + ctx context.Context + store store.Store + datastore datastore.Datastore + testDB *pgtest.TestPostgres +} + +func TestBlobsStore(t *testing.T) { + suite.Run(t, new(snapshotTestSuite)) +} + +func (s *snapshotTestSuite) SetupSuite() { + s.ctx = sac.WithAllAccess(context.Background()) + s.testDB = pgtest.ForT(s.T()) + s.store = store.New(s.testDB.DB) + s.datastore = datastore.NewDatastore(s.store) +} + +func (s *snapshotTestSuite) SetupTest() { + tag, err := s.testDB.Exec(s.ctx, "TRUNCATE blobs CASCADE") + s.T().Log("blobs", tag) + s.NoError(err) +} + +func (s *snapshotTestSuite) TearDownSuite() { + s.testDB.Teardown(s.T()) +} + +func (s *snapshotTestSuite) TestSnapshot() { + ctx := sac.WithAllAccess(context.Background()) + size := 1024*1024 + 16 + insertBlob := &storage.Blob{ + Name: "test", + LastUpdated: timestamp.TimestampNow(), + ModifiedTime: timestamp.TimestampNow(), + Length: int64(size), + } + + randomData := make([]byte, size) + _, err := rand.Read(randomData) + s.NoError(err) + + reader := bytes.NewBuffer(randomData) + + s.Require().NoError(s.store.Upsert(ctx, insertBlob, reader)) + + snap, err := TakeBlobSnapshot(ctx, s.datastore, insertBlob.GetName()) + s.NoError(err) + defer func() { + s.NoError(snap.Close()) + s.NoFileExists(snap.path) + }() + bytes, err := io.ReadAll(snap) + s.Require().NoError(err) + s.Equal(randomData, bytes) + s.Equal(insertBlob, snap.GetBlob()) + s.FileExists(snap.path) +} diff --git a/central/blob/snapshot/utils.go b/central/blob/snapshot/utils.go deleted file mode 100644 index 0a8992b056c7c..0000000000000 --- a/central/blob/snapshot/utils.go +++ /dev/null @@ -1,48 +0,0 @@ -package snapshot - -import ( - "context" - "os" - "path/filepath" - - "github.com/pkg/errors" - "github.com/stackrox/rox/central/blob/datastore" - "github.com/stackrox/rox/generated/storage" -) - -// TakeBlobSnapshot create a Snapshot for 
the named blob. -func TakeBlobSnapshot(blobStore datastore.Datastore, name string) (rnc Snapshot, err error) { - var tempDir string - tempDir, err = os.MkdirTemp("", "blob-file-") - if err != nil { - return - } - defer func() { - if err != nil { - _ = os.RemoveAll(tempDir) - } - }() - - tempFile := filepath.Join(tempDir, "blob.data") - var writer *os.File - writer, err = os.Create(tempFile) - if err != nil { - return - } - - var blob *storage.Blob - blob, _, err = blobStore.Get(context.Background(), name, writer) - if err != nil { - err = errors.Wrapf(err, "failed to open blob with name %q", name) - return - } - err = writer.Close() - if err != nil { - return - } - rnc, err = NewBlobSnapshot(tempFile, blob) - if err != nil { - return - } - return -} diff --git a/central/scannerdefinitions/handler/definition_file.go b/central/scannerdefinitions/handler/definition_file.go deleted file mode 100644 index ac4e0ce84bf09..0000000000000 --- a/central/scannerdefinitions/handler/definition_file.go +++ /dev/null @@ -1,11 +0,0 @@ -package handler - -import "io" - -// definitionFileReader unifies online and offline reader for vuln definitions -type definitionFileReader interface { - io.Reader - io.Closer - io.Seeker - Name() string -} diff --git a/central/scannerdefinitions/handler/handler.go b/central/scannerdefinitions/handler/handler.go index d1bd4af5a0462..2ca5ca2fd037b 100644 --- a/central/scannerdefinitions/handler/handler.go +++ b/central/scannerdefinitions/handler/handler.go @@ -20,6 +20,7 @@ import ( "github.com/stackrox/rox/central/scannerdefinitions/file" "github.com/stackrox/rox/generated/storage" "github.com/stackrox/rox/pkg/env" + "github.com/stackrox/rox/pkg/errorhelpers" "github.com/stackrox/rox/pkg/fileutils" "github.com/stackrox/rox/pkg/httputil" "github.com/stackrox/rox/pkg/httputil/proxy" @@ -78,13 +79,13 @@ type httpHandler struct { } // New creates a new http.Handler to handle vulnerability data. 
-func New(cveManager fetcher.OrchestratorIstioCVEManager, opts handlerOpts) http.Handler { +func New(cveManager fetcher.OrchestratorIstioCVEManager, blobStore blob.Datastore, opts handlerOpts) http.Handler { h := &httpHandler{ cveManager: cveManager, online: !env.OfflineModeEnv.BooleanSetting(), interval: env.ScannerVulnUpdateInterval.DurationSetting(), - blobStore: blob.Singleton(), + blobStore: blobStore, } if h.online { @@ -119,7 +120,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *httpHandler) get(w http.ResponseWriter, r *http.Request) { // Open the most recent definitions file for the provided `uuid`. uuid := r.URL.Query().Get(`uuid`) - f, modTime, err := h.openMostRecentDefinitions(uuid) + f, modTime, err := h.openMostRecentDefinitions(r.Context(), uuid) if err != nil { writeErrorForFile(w, err, uuid) return @@ -156,12 +157,12 @@ func writeErrorNotFound(w http.ResponseWriter) { } func writeErrorForFile(w http.ResponseWriter, err error, path string) { - if errors.Is(err, fs.ErrNotExist) { + if errorhelpers.IsAny(err, fs.ErrNotExist, snapshot.ErrBlobNotExist) { writeErrorNotFound(w) return } - httputil.WriteGRPCStyleErrorf(w, codes.Internal, "could not read file %s: %v", filepath.Base(path), err) + httputil.WriteGRPCStyleErrorf(w, codes.Internal, "could not read vulnerability definition %s: %v", filepath.Base(path), err) } func serveContent(w http.ResponseWriter, r *http.Request, name string, modTime time.Time, content io.ReadSeeker) { @@ -203,7 +204,7 @@ func (h *httpHandler) updateK8sIstioCVEs(zipPath string) { } } -func (h *httpHandler) handleScannerDefsFile(zipF *zip.File) error { +func (h *httpHandler) handleScannerDefsFile(zipF *zip.File, ctx context.Context) error { r, err := zipF.Open() if err != nil { return errors.Wrap(err, "opening ZIP reader") @@ -218,14 +219,14 @@ func (h *httpHandler) handleScannerDefsFile(zipF *zip.File) error { Length: zipF.FileInfo().Size(), } - if err := 
h.blobStore.Upsert(context.Background(), b, r); err != nil { + if err := h.blobStore.Upsert(ctx, b, r); err != nil { return errors.Wrap(err, "writing scanner definitions") } return nil } -func (h *httpHandler) handleZipContentsFromVulnDump(zipPath string) error { +func (h *httpHandler) handleZipContentsFromVulnDump(ctx context.Context, zipPath string) error { zipR, err := zip.OpenReader(zipPath) if err != nil { return errors.Wrap(err, "couldn't open file as zip") @@ -235,7 +236,7 @@ func (h *httpHandler) handleZipContentsFromVulnDump(zipPath string) error { var scannerDefsFileFound bool for _, zipF := range zipR.File { if zipF.Name == scannerDefsSubZipName { - if err := h.handleScannerDefsFile(zipF); err != nil { + if err := h.handleScannerDefsFile(zipF, ctx); err != nil { return errors.Wrap(err, "couldn't handle scanner-defs sub file") } scannerDefsFileFound = true @@ -269,7 +270,7 @@ func (h *httpHandler) post(w http.ResponseWriter, r *http.Request) { return } - if err := h.handleZipContentsFromVulnDump(tempFile); err != nil { + if err := h.handleZipContentsFromVulnDump(r.Context(), tempFile); err != nil { httputil.WriteGRPCStyleError(w, codes.Internal, err) return } @@ -309,13 +310,20 @@ func (h *httpHandler) cleanupUpdaters(cleanupAge time.Duration) { } } -func (h *httpHandler) openOfflineBlob() (definitionFileReader, time.Time, error) { - snapshot, err := snapshot.TakeBlobSnapshot(h.blobStore, offlineScannerDefsName) +func (h *httpHandler) openOfflineBlob(ctx context.Context) (*os.File, time.Time, error) { + snap, err := snapshot.TakeBlobSnapshot(ctx, h.blobStore, offlineScannerDefsName) if err != nil { + // If the blob does not exist, return no reader. 
+ if errors.Is(err, snapshot.ErrBlobNotExist) { + return nil, time.Time{}, nil + } return nil, time.Time{}, err } - modTime := *pgutils.NilOrTime(snapshot.GetBlob().ModifiedTime) - return snapshot, modTime, nil + modTime := time.Time{} + if t := pgutils.NilOrTime(snap.GetBlob().ModifiedTime); t != nil { + modTime = *t + } + return snap.File, modTime, nil } // openMostRecentDefinitions opens the latest Scanner Definitions based on @@ -323,10 +331,10 @@ func (h *httpHandler) openOfflineBlob() (definitionFileReader, time.Time, error) // online, otherwise fallback to the manually uploaded definitions. The file // object can be `nil` if the definitions file does not exist, rather than // returning an error. -func (h *httpHandler) openMostRecentDefinitions(uuid string) (rc definitionFileReader, modTime time.Time, err error) { +func (h *httpHandler) openMostRecentDefinitions(ctx context.Context, uuid string) (file *os.File, modTime time.Time, err error) { // If in offline mode or uuid is not provided, default to the offline file. if !h.online || uuid == "" { - rc, modTime, err = h.openOfflineBlob() + file, modTime, err = h.openOfflineBlob(ctx) return } @@ -336,28 +344,32 @@ func (h *httpHandler) openMostRecentDefinitions(uuid string) (rc definitionFileR u := h.getUpdater(uuid) u.Start() + toClose := func(f *os.File) { + if file != f && f != nil { + utils.IgnoreError(f.Close) + } + } + // Open both the "online" and "offline", and save their modification times. onlineFile, onlineTime, err := u.file.Open() if err != nil { return } - offlineFile, offlineTime, err := h.openOfflineBlob() + defer toClose(onlineFile) + offlineFile, offlineTime, err := h.openOfflineBlob(ctx) if err != nil { - utils.IgnoreError(onlineFile.Close) return } + defer toClose(offlineFile) // Return the most recent file, notice that if both don't exist, nil is returned // since modification time will be zero. 
if offlineTime.After(onlineTime) { - rc, modTime = offlineFile, offlineTime - utils.IgnoreError(onlineFile.Close) + file, modTime = offlineFile, offlineTime } else { - rc, modTime = onlineFile, onlineTime - utils.IgnoreError(offlineFile.Close) + file, modTime = onlineFile, onlineTime } - return } diff --git a/central/scannerdefinitions/handler/handler_test.go b/central/scannerdefinitions/handler/handler_test.go index c6d3b15c712bd..ce8a04bd29826 100644 --- a/central/scannerdefinitions/handler/handler_test.go +++ b/central/scannerdefinitions/handler/handler_test.go @@ -1,118 +1,151 @@ +//go:build sql_integration + package handler import ( + "bytes" + "context" "fmt" "net/http" "os" - "path/filepath" "testing" "time" + "github.com/gogo/protobuf/types" + "github.com/stackrox/rox/central/blob/datastore" + "github.com/stackrox/rox/central/blob/datastore/store" + "github.com/stackrox/rox/generated/storage" "github.com/stackrox/rox/pkg/env" "github.com/stackrox/rox/pkg/httputil/mock" + "github.com/stackrox/rox/pkg/postgres/pgtest" + "github.com/stackrox/rox/pkg/sac" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +const ( + content1 = "Hello, world!" 
+ content2 = "Papaya" ) -func mustGetRequest(t *testing.T) *http.Request { +type handlerTestSuite struct { + suite.Suite + ctx context.Context + datastore datastore.Datastore + testDB *pgtest.TestPostgres +} + +func TestHandler(t *testing.T) { + suite.Run(t, new(handlerTestSuite)) +} + +func (s *handlerTestSuite) SetupSuite() { + s.ctx = sac.WithAllAccess(context.Background()) + s.testDB = pgtest.ForT(s.T()) + blobStore := store.New(s.testDB.DB) + s.datastore = datastore.NewDatastore(blobStore) +} + +func (s *handlerTestSuite) SetupTest() { + tag, err := s.testDB.Exec(s.ctx, "TRUNCATE blobs CASCADE") + s.T().Log("blobs", tag) + s.NoError(err) +} + +func (s *handlerTestSuite) TearDownSuite() { + s.testDB.Teardown(s.T()) +} + +func (s *handlerTestSuite) mustGetRequest(t *testing.T) *http.Request { centralURL := "https://central.stackrox.svc/scannerdefinitions?uuid=e799c68a-671f-44db-9682-f24248cd0ffe" - req, err := http.NewRequest(http.MethodGet, centralURL, nil) + req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, centralURL, nil) + require.NoError(t, err) return req } -func mustGetRequestWithFile(t *testing.T, file string) *http.Request { +func (s *handlerTestSuite) mustGetRequestWithFile(t *testing.T, file string) *http.Request { centralURL := fmt.Sprintf("https://central.stackrox.svc/scannerdefinitions?uuid=e799c68a-671f-44db-9682-f24248cd0ffe&file=%s", file) - req, err := http.NewRequest(http.MethodGet, centralURL, nil) + req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, centralURL, nil) require.NoError(t, err) return req } -func mustGetBadRequest(t *testing.T) *http.Request { +func (s *handlerTestSuite) mustGetBadRequest(t *testing.T) *http.Request { centralURL := "https://central.stackrox.svc/scannerdefinitions?uuid=fail" - req, err := http.NewRequest(http.MethodGet, centralURL, nil) + req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, centralURL, nil) require.NoError(t, err) return req } -func TestServeHTTP_Offline_Get(t 
*testing.T) { +func (s *handlerTestSuite) TestServeHTTP_Offline_Get() { + t := s.T() t.Setenv(env.OfflineModeEnv.EnvVar(), "true") - tmpDir := t.TempDir() - h := New(nil, handlerOpts{ - offlineVulnDefsDir: tmpDir, - }) + h := New(nil, s.datastore, handlerOpts{}) // No scanner defs found. - req := mustGetRequest(t) + req := s.mustGetRequest(t) w := mock.NewResponseWriter() h.ServeHTTP(w, req) assert.Equal(t, http.StatusNotFound, w.Code) // Add scanner defs. - f, err := os.Create(filepath.Join(tmpDir, offlineScannerDefsName)) - require.NoError(t, err) - _, err = f.Write([]byte("Hello, World!")) - require.NoError(t, err) + s.mustWriteOffline(content1, time.Now()) w.Data.Reset() h.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "Hello, World!", w.Data.String()) + assert.Equal(t, content1, w.Data.String()) } -func TestServeHTTP_Online_Get(t *testing.T) { - tmpDir := t.TempDir() - h := New(nil, handlerOpts{ - offlineVulnDefsDir: tmpDir, - }) +func (s *handlerTestSuite) TestServeHTTP_Online_Get() { + t := s.T() + h := New(nil, s.datastore, handlerOpts{}) w := mock.NewResponseWriter() // Should not get anything. - req := mustGetBadRequest(t) + req := s.mustGetBadRequest(t) h.ServeHTTP(w, req) assert.Equal(t, http.StatusNotFound, w.Code) // Should get file from online update. - req = mustGetRequestWithFile(t, "manifest.json") + req = s.mustGetRequestWithFile(t, "manifest.json") w.Data.Reset() h.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "application/json", w.Header().Get("Content-Type")) - assert.Regexpf(t, `{"since":".*","until":".*"}`, w.Data.String(), "content did not match") + assert.Regexpf(t, `{"since":".*","until":".*"}`, w.Data.String(), "content1 did not match") // Should get online update. - req = mustGetRequest(t) + req = s.mustGetRequest(t) h.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) // Write offline definitions. 
- f, err := os.Create(filepath.Join(tmpDir, offlineScannerDefsName)) - require.NoError(t, err) - _, err = f.Write([]byte("Hello, World!")) - require.NoError(t, err) + s.mustWriteOffline(content1, time.Now()) // Set the offline dump's modified time to later than the online update's. - handler := h.(*httpHandler) - mustSetModTime(t, handler.offlineFile.Path(), time.Now().Add(time.Minute)) + s.mustWriteOffline(content1, time.Now().Add(time.Hour)) // Served the offline dump, as it is more recent. w.Data.Reset() h.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "Hello, World!", w.Data.String()) + assert.Equal(t, content1, w.Data.String()) // Set the offline dump's modified time to earlier than the online update's. - mustSetModTime(t, handler.offlineFile.Path(), nov23) + s.mustWriteOffline(content2, nov23) // Serve the online dump, as it is now more recent. w.Data.Reset() h.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) - assert.NotEqual(t, "Hello, World!", w.Data.String()) + assert.NotEqual(t, content2, w.Data.String()) // File is unmodified. 
req.Header.Set(ifModifiedSinceHeader, time.Now().UTC().Format(http.TimeFormat)) @@ -125,3 +158,13 @@ func TestServeHTTP_Online_Get(t *testing.T) { func mustSetModTime(t *testing.T, path string, modTime time.Time) { require.NoError(t, os.Chtimes(path, time.Now(), modTime)) } + +func (s *handlerTestSuite) mustWriteOffline(content string, modTime time.Time) { + modifiedTime, err := types.TimestampProto(modTime) + s.NoError(err) + blob := &storage.Blob{ + Name: offlineScannerDefsName, + ModifiedTime: modifiedTime, + } + s.Require().NoError(s.datastore.Upsert(s.ctx, blob, bytes.NewBuffer([]byte(content)))) +} diff --git a/central/scannerdefinitions/handler/options.go b/central/scannerdefinitions/handler/options.go index f42733f7f9beb..148814833e101 100644 --- a/central/scannerdefinitions/handler/options.go +++ b/central/scannerdefinitions/handler/options.go @@ -4,13 +4,7 @@ import "time" // handlerOpts represents the options for a scannerdefinitions http.Handler. type handlerOpts struct { - // offlineVulnDefsDir is the directory in which persisted vulnerability definitions should be written. - // It is assumed the directory already exists. - // Default: /var/lib/stackrox/scannerdefinitions - offlineVulnDefsDir string - // The following are options which are only respected in online-mode. - // cleanupInterval sets the interval for cleaning up updaters. cleanupInterval *time.Duration // cleanupAge sets the age after which an updater should be cleaned. diff --git a/central/scannerdefinitions/handler/singleton.go b/central/scannerdefinitions/handler/singleton.go index 9667f32865871..c5d22f208a8bd 100644 --- a/central/scannerdefinitions/handler/singleton.go +++ b/central/scannerdefinitions/handler/singleton.go @@ -3,6 +3,7 @@ package handler import ( "net/http" + blob "github.com/stackrox/rox/central/blob/datastore" "github.com/stackrox/rox/central/cve/fetcher" "github.com/stackrox/rox/pkg/sync" ) @@ -15,7 +16,7 @@ var ( // Singleton returns the singleton service handler. 
func Singleton() http.Handler { once.Do(func() { - singleton = New(fetcher.SingletonManager(), handlerOpts{}) + singleton = New(fetcher.SingletonManager(), blob.Singleton(), handlerOpts{}) }) return singleton } diff --git a/central/scannerdefinitions/handler/vul_def_reader.go b/central/scannerdefinitions/handler/vul_def_reader.go new file mode 100644 index 0000000000000..9f9ec11940a03 --- /dev/null +++ b/central/scannerdefinitions/handler/vul_def_reader.go @@ -0,0 +1,11 @@ +package handler + +import "io" + +// vulDefReader unifies online and offline reader for vuln definitions +type vulDefReader interface { + io.Reader + io.Closer + io.Seeker + Name() string +} From 1e36fc332a39e6f207a8d0c38debad565fef3c53 Mon Sep 17 00:00:00 2001 From: cdu Date: Thu, 11 May 2023 01:56:20 -0700 Subject: [PATCH 15/40] stage --- central/blob/snapshot/snapshot.go | 31 +++++----- central/blob/snapshot/snapshot_test.go | 4 +- central/scannerdefinitions/handler/handler.go | 59 ++++++++++--------- .../handler/handler_test.go | 15 ++++- .../handler/scanner_definition.go | 20 +++++++ .../handler/vul_def_reader.go | 11 ---- 6 files changed, 82 insertions(+), 58 deletions(-) create mode 100644 central/scannerdefinitions/handler/scanner_definition.go delete mode 100644 central/scannerdefinitions/handler/vul_def_reader.go diff --git a/central/blob/snapshot/snapshot.go b/central/blob/snapshot/snapshot.go index bf52813ba361a..2bd5dce3f9756 100644 --- a/central/blob/snapshot/snapshot.go +++ b/central/blob/snapshot/snapshot.go @@ -13,26 +13,19 @@ import ( ) var ( - log = logging.LoggerForModule() + log = logging.LoggerForModule() + + // ErrBlobNotExist is Blob does not exist error ErrBlobNotExist = errors.New("cannot find blob") ) -// NewBlobSnapshot creates a snapshot of the blob backed up by a temp dir with the blob data stored. -// The temp dir is removed on snapshot closure. 
-func NewBlobSnapshot(p string, blob *storage.Blob) (*Snapshot, error) { - file, err := os.Open(p) - if err != nil { - return nil, err - } - return &Snapshot{File: file, blob: blob, path: p}, nil -} - // Snapshot contains a Blob with its data backed by a temp file. // The temp file will be removed on close. type Snapshot struct { *os.File - blob *storage.Blob - path string + blob *storage.Blob + tmpDir string + baseName string } // Close temp file and remove its temp dir. @@ -41,12 +34,13 @@ func (s *Snapshot) Close() error { return nil } err := s.File.Close() - if removeErr := os.RemoveAll(s.path); err != nil { - log.Errorf("failed to remove %q: %v", s.path, removeErr) + if removeErr := os.RemoveAll(s.tmpDir); removeErr != nil { + log.Errorf("failed to remove %q: %v", s.tmpDir, removeErr) } return err } +// GetBlob returns Blob func (s *Snapshot) GetBlob() *storage.Blob { return s.blob } @@ -90,5 +84,10 @@ func TakeBlobSnapshot(ctx context.Context, blobStore datastore.Datastore, name s err = ErrBlobNotExist return } - return NewBlobSnapshot(tempFile, blob) + + file, err := os.Open(tempFile) + if err != nil { + return nil, err + } + return &Snapshot{File: file, blob: blob, tmpDir: tempDir, baseName: baseName}, nil } diff --git a/central/blob/snapshot/snapshot_test.go b/central/blob/snapshot/snapshot_test.go index eacc12cebb9cf..e57277248404b 100644 --- a/central/blob/snapshot/snapshot_test.go +++ b/central/blob/snapshot/snapshot_test.go @@ -69,11 +69,11 @@ func (s *snapshotTestSuite) TestSnapshot() { s.NoError(err) defer func() { s.NoError(snap.Close()) - s.NoFileExists(snap.path) + s.NoFileExists(snap.Name()) }() bytes, err := io.ReadAll(snap) s.Require().NoError(err) s.Equal(randomData, bytes) s.Equal(insertBlob, snap.GetBlob()) - s.FileExists(snap.path) + s.FileExists(snap.Name()) } diff --git a/central/scannerdefinitions/handler/handler.go b/central/scannerdefinitions/handler/handler.go index 2ca5ca2fd037b..7e6a33db5d188 100644 --- 
a/central/scannerdefinitions/handler/handler.go +++ b/central/scannerdefinitions/handler/handler.go @@ -120,7 +120,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *httpHandler) get(w http.ResponseWriter, r *http.Request) { // Open the most recent definitions file for the provided `uuid`. uuid := r.URL.Query().Get(`uuid`) - f, modTime, err := h.openMostRecentDefinitions(r.Context(), uuid) + f, err := h.openMostRecentDefinitions(r.Context(), uuid) if err != nil { writeErrorForFile(w, err, uuid) return @@ -136,19 +136,21 @@ func (h *httpHandler) get(w http.ResponseWriter, r *http.Request) { defer utils.IgnoreError(f.Close) - // If `file` was provided, extract from definitions' bundle to a - // temporary file and serve that instead. fileName := r.URL.Query().Get(`file`) - if fileName != "" { - f, err = openFromArchive(f.Name(), fileName) - if err != nil { - writeErrorForFile(w, err, fileName) - return - } - defer utils.IgnoreError(f.Close) + if fileName == "" { + serveContent(w, r, f.Name(), f.modTime, f) + return } - serveContent(w, r, f.Name(), modTime, f) + // If `file` was provided, extract from definitions' bundle to a + // temporary file and serve that instead. 
+ namedFile, err := openFromArchive(f.Name(), fileName) + if err != nil { + writeErrorForFile(w, err, fileName) + return + } + defer utils.IgnoreError(namedFile.Close) + serveContent(w, r, namedFile.Name(), f.modTime, namedFile) } func writeErrorNotFound(w http.ResponseWriter) { @@ -204,7 +206,7 @@ func (h *httpHandler) updateK8sIstioCVEs(zipPath string) { } } -func (h *httpHandler) handleScannerDefsFile(zipF *zip.File, ctx context.Context) error { +func (h *httpHandler) handleScannerDefsFile(ctx context.Context, zipF *zip.File) error { r, err := zipF.Open() if err != nil { return errors.Wrap(err, "opening ZIP reader") @@ -236,7 +238,7 @@ func (h *httpHandler) handleZipContentsFromVulnDump(ctx context.Context, zipPath var scannerDefsFileFound bool for _, zipF := range zipR.File { if zipF.Name == scannerDefsSubZipName { - if err := h.handleScannerDefsFile(zipF, ctx); err != nil { + if err := h.handleScannerDefsFile(ctx, zipF); err != nil { return errors.Wrap(err, "couldn't handle scanner-defs sub file") } scannerDefsFileFound = true @@ -310,20 +312,20 @@ func (h *httpHandler) cleanupUpdaters(cleanupAge time.Duration) { } } -func (h *httpHandler) openOfflineBlob(ctx context.Context) (*os.File, time.Time, error) { +func (h *httpHandler) openOfflineBlob(ctx context.Context) (*vulDefFile, error) { snap, err := snapshot.TakeBlobSnapshot(ctx, h.blobStore, offlineScannerDefsName) if err != nil { // If the blob does not exist, return no reader. 
if errors.Is(err, snapshot.ErrBlobNotExist) { - return nil, time.Time{}, nil + return nil, nil } - return nil, time.Time{}, err + return nil, err } modTime := time.Time{} if t := pgutils.NilOrTime(snap.GetBlob().ModifiedTime); t != nil { modTime = *t } - return snap.File, modTime, nil + return &vulDefFile{snap.File, modTime, snap.Close}, nil } // openMostRecentDefinitions opens the latest Scanner Definitions based on @@ -331,10 +333,10 @@ func (h *httpHandler) openOfflineBlob(ctx context.Context) (*os.File, time.Time, // online, otherwise fallback to the manually uploaded definitions. The file // object can be `nil` if the definitions file does not exist, rather than // returning an error. -func (h *httpHandler) openMostRecentDefinitions(ctx context.Context, uuid string) (file *os.File, modTime time.Time, err error) { +func (h *httpHandler) openMostRecentDefinitions(ctx context.Context, uuid string) (file *vulDefFile, err error) { // If in offline mode or uuid is not provided, default to the offline file. if !h.online || uuid == "" { - file, modTime, err = h.openOfflineBlob(ctx) + file, err = h.openOfflineBlob(ctx) return } @@ -344,19 +346,24 @@ func (h *httpHandler) openMostRecentDefinitions(ctx context.Context, uuid string u := h.getUpdater(uuid) u.Start() - toClose := func(f *os.File) { + toClose := func(f *vulDefFile) { if file != f && f != nil { utils.IgnoreError(f.Close) } } // Open both the "online" and "offline", and save their modification times. 
- onlineFile, onlineTime, err := u.file.Open() + var onlineFile *vulDefFile + onlineOsFile, onlineTime, err := u.file.Open() if err != nil { return } + if onlineOsFile != nil { + onlineFile = &vulDefFile{File: onlineOsFile, modTime: onlineTime} + } + defer toClose(onlineFile) - offlineFile, offlineTime, err := h.openOfflineBlob(ctx) + offlineFile, err := h.openOfflineBlob(ctx) if err != nil { return } @@ -364,11 +371,9 @@ func (h *httpHandler) openMostRecentDefinitions(ctx context.Context, uuid string // Return the most recent file, notice that if both don't exist, nil is returned // since modification time will be zero. - - if offlineTime.After(onlineTime) { - file, modTime = offlineFile, offlineTime - } else { - file, modTime = onlineFile, onlineTime + file = onlineFile + if offlineFile != nil && offlineFile.modTime.After(onlineTime) { + file = offlineFile } return } diff --git a/central/scannerdefinitions/handler/handler_test.go b/central/scannerdefinitions/handler/handler_test.go index ce8a04bd29826..4553eedb53751 100644 --- a/central/scannerdefinitions/handler/handler_test.go +++ b/central/scannerdefinitions/handler/handler_test.go @@ -1,5 +1,3 @@ -//go:build sql_integration - package handler import ( @@ -8,6 +6,7 @@ import ( "fmt" "net/http" "os" + "strings" "testing" "time" @@ -19,6 +18,7 @@ import ( "github.com/stackrox/rox/pkg/httputil/mock" "github.com/stackrox/rox/pkg/postgres/pgtest" "github.com/stackrox/rox/pkg/sac" + "github.com/stackrox/rox/pkg/utils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -34,6 +34,7 @@ type handlerTestSuite struct { ctx context.Context datastore datastore.Datastore testDB *pgtest.TestPostgres + tmpDir string } func TestHandler(t *testing.T) { @@ -45,6 +46,10 @@ func (s *handlerTestSuite) SetupSuite() { s.testDB = pgtest.ForT(s.T()) blobStore := store.New(s.testDB.DB) s.datastore = datastore.NewDatastore(blobStore) + var err error + s.tmpDir, err = 
os.MkdirTemp("", "handler-test") + s.Require().NoError(err) + s.T().Setenv("TMPDIR", s.tmpDir) } func (s *handlerTestSuite) SetupTest() { @@ -54,7 +59,13 @@ func (s *handlerTestSuite) SetupTest() { } func (s *handlerTestSuite) TearDownSuite() { + entries, err := os.ReadDir(s.tmpDir) + s.NoError(err) + s.Len(entries, 1) + s.True(strings.HasPrefix(entries[0].Name(), definitionsBaseDir)) + s.testDB.Teardown(s.T()) + utils.IgnoreError(func() error { return os.RemoveAll(s.tmpDir) }) } func (s *handlerTestSuite) mustGetRequest(t *testing.T) *http.Request { diff --git a/central/scannerdefinitions/handler/scanner_definition.go b/central/scannerdefinitions/handler/scanner_definition.go new file mode 100644 index 0000000000000..dae3fa416a245 --- /dev/null +++ b/central/scannerdefinitions/handler/scanner_definition.go @@ -0,0 +1,20 @@ +package handler + +import ( + "os" + "time" +) + +// vulDefFile unifies online and offline reader for scanner definitions +type vulDefFile struct { + *os.File + modTime time.Time + closer func() error +} + +func (f *vulDefFile) Close() error { + if f.closer != nil { + return f.closer() + } + return f.File.Close() +} diff --git a/central/scannerdefinitions/handler/vul_def_reader.go b/central/scannerdefinitions/handler/vul_def_reader.go deleted file mode 100644 index 9f9ec11940a03..0000000000000 --- a/central/scannerdefinitions/handler/vul_def_reader.go +++ /dev/null @@ -1,11 +0,0 @@ -package handler - -import "io" - -// vulDefReader unifies online and offline reader for vuln definitions -type vulDefReader interface { - io.Reader - io.Closer - io.Seeker - Name() string -} From 1312343b27d44cff242cf7c6b039a21779b9d4dd Mon Sep 17 00:00:00 2001 From: cdu Date: Thu, 11 May 2023 11:20:40 -0700 Subject: [PATCH 16/40] first break it --- central/probeupload/manager/manager_impl.go | 1 - 1 file changed, 1 deletion(-) diff --git a/central/probeupload/manager/manager_impl.go b/central/probeupload/manager/manager_impl.go index 6462785aa042a..da22f0b6a55d2 
100644 --- a/central/probeupload/manager/manager_impl.go +++ b/central/probeupload/manager/manager_impl.go @@ -49,7 +49,6 @@ type manager struct { func newManager(persistenceRoot string) *manager { return &manager{ - rootDir: filepath.Join(persistenceRoot, rootDirName), freeDiskThreshold: defaultFreeDiskThreshold, } } From e57c2f7c41e0cce2fc7d17b0c7d27d9c97da9247 Mon Sep 17 00:00:00 2001 From: cdu Date: Mon, 15 May 2023 11:38:28 -0700 Subject: [PATCH 17/40] security --- central/scannerdefinitions/handler/handler.go | 10 ++++++---- central/scannerdefinitions/handler/handler_test.go | 6 ++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/central/scannerdefinitions/handler/handler.go b/central/scannerdefinitions/handler/handler.go index 7e6a33db5d188..c78e28e2e5f27 100644 --- a/central/scannerdefinitions/handler/handler.go +++ b/central/scannerdefinitions/handler/handler.go @@ -26,6 +26,7 @@ import ( "github.com/stackrox/rox/pkg/httputil/proxy" "github.com/stackrox/rox/pkg/logging" "github.com/stackrox/rox/pkg/postgres/pgutils" + "github.com/stackrox/rox/pkg/sac" "github.com/stackrox/rox/pkg/sync" "github.com/stackrox/rox/pkg/utils" "google.golang.org/grpc/codes" @@ -221,7 +222,7 @@ func (h *httpHandler) handleScannerDefsFile(ctx context.Context, zipF *zip.File) Length: zipF.FileInfo().Size(), } - if err := h.blobStore.Upsert(ctx, b, r); err != nil { + if err := h.blobStore.Upsert(sac.WithAllAccess(ctx), b, r); err != nil { return errors.Wrap(err, "writing scanner definitions") } @@ -313,12 +314,14 @@ func (h *httpHandler) cleanupUpdaters(cleanupAge time.Duration) { } func (h *httpHandler) openOfflineBlob(ctx context.Context) (*vulDefFile, error) { - snap, err := snapshot.TakeBlobSnapshot(ctx, h.blobStore, offlineScannerDefsName) + snap, err := snapshot.TakeBlobSnapshot(sac.WithAllAccess(ctx), h.blobStore, offlineScannerDefsName) if err != nil { // If the blob does not exist, return no reader. 
if errors.Is(err, snapshot.ErrBlobNotExist) { + log.Warnf("Blob %s does not exist", offlineScannerDefsName) return nil, nil } + log.Warnf("Cannnot take a snapshot of Blob %q: %v", offlineScannerDefsName, err) return nil, err } modTime := time.Time{} @@ -336,8 +339,7 @@ func (h *httpHandler) openOfflineBlob(ctx context.Context) (*vulDefFile, error) func (h *httpHandler) openMostRecentDefinitions(ctx context.Context, uuid string) (file *vulDefFile, err error) { // If in offline mode or uuid is not provided, default to the offline file. if !h.online || uuid == "" { - file, err = h.openOfflineBlob(ctx) - return + return h.openOfflineBlob(ctx) } // Start the updater, can be called multiple times for the same uuid, but will diff --git a/central/scannerdefinitions/handler/handler_test.go b/central/scannerdefinitions/handler/handler_test.go index 4553eedb53751..8edc6ca12ff23 100644 --- a/central/scannerdefinitions/handler/handler_test.go +++ b/central/scannerdefinitions/handler/handler_test.go @@ -61,8 +61,10 @@ func (s *handlerTestSuite) SetupTest() { func (s *handlerTestSuite) TearDownSuite() { entries, err := os.ReadDir(s.tmpDir) s.NoError(err) - s.Len(entries, 1) - s.True(strings.HasPrefix(entries[0].Name(), definitionsBaseDir)) + s.Less(len(entries), 1) + if len(entries) == 1 { + s.True(strings.HasPrefix(entries[0].Name(), definitionsBaseDir)) + } s.testDB.Teardown(s.T()) utils.IgnoreError(func() error { return os.RemoveAll(s.tmpDir) }) From f254423265552fe9a42ad918833f53acc4d595d0 Mon Sep 17 00:00:00 2001 From: cdu Date: Mon, 15 May 2023 13:02:38 -0700 Subject: [PATCH 18/40] revert sneak --- central/probeupload/manager/manager_impl.go | 1 + 1 file changed, 1 insertion(+) diff --git a/central/probeupload/manager/manager_impl.go b/central/probeupload/manager/manager_impl.go index da22f0b6a55d2..6462785aa042a 100644 --- a/central/probeupload/manager/manager_impl.go +++ b/central/probeupload/manager/manager_impl.go @@ -49,6 +49,7 @@ type manager struct { func 
newManager(persistenceRoot string) *manager { return &manager{ + rootDir: filepath.Join(persistenceRoot, rootDirName), freeDiskThreshold: defaultFreeDiskThreshold, } } From b2f6f6b62309ca0cb7b98f1de9446b7cecabda14 Mon Sep 17 00:00:00 2001 From: cdu Date: Mon, 15 May 2023 13:02:38 -0700 Subject: [PATCH 19/40] revert sneak --- central/probeupload/manager/manager_impl.go | 1 + 1 file changed, 1 insertion(+) diff --git a/central/probeupload/manager/manager_impl.go b/central/probeupload/manager/manager_impl.go index da22f0b6a55d2..6462785aa042a 100644 --- a/central/probeupload/manager/manager_impl.go +++ b/central/probeupload/manager/manager_impl.go @@ -49,6 +49,7 @@ type manager struct { func newManager(persistenceRoot string) *manager { return &manager{ + rootDir: filepath.Join(persistenceRoot, rootDirName), freeDiskThreshold: defaultFreeDiskThreshold, } } From 39203562578d416eb5b5a8041bd73e3cf83f9db5 Mon Sep 17 00:00:00 2001 From: cdu Date: Mon, 15 May 2023 15:28:45 -0700 Subject: [PATCH 20/40] fix --- central/scannerdefinitions/handler/handler.go | 15 ++++++--------- .../scannerdefinitions/handler/handler_test.go | 4 ++-- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/central/scannerdefinitions/handler/handler.go b/central/scannerdefinitions/handler/handler.go index c78e28e2e5f27..77c1a47012bd8 100644 --- a/central/scannerdefinitions/handler/handler.go +++ b/central/scannerdefinitions/handler/handler.go @@ -35,16 +35,13 @@ import ( const ( definitionsBaseDir = "scannerdefinitions" - scannerDefinitionBlobName = "offline.scanner.definitions" - scannerDefinitionTemp = "scanner-defs-*.zip" - // scannerDefsSubZipName represents the offline zip bundle for CVEs for Scanner. scannerDefsSubZipName = "scanner-defs.zip" // K8sIstioCveZipName represents the zip bundle for k8s/istio CVEs. K8sIstioCveZipName = "k8s-istio.zip" - // offlineScannerDefsName represents the offline/fallback zip bundle for CVEs for Scanner. 
- offlineScannerDefsName = scannerDefsSubZipName + // offlineScannerDefinitionBlobName represents the blob name of offline/fallback zip bundle for CVEs for Scanner. + offlineScannerDefinitionBlobName = "/offline/scanner/" + scannerDefsSubZipName scannerUpdateDomain = "https://definitions.stackrox.io" scannerUpdateURLSuffix = "diff.zip" @@ -216,7 +213,7 @@ func (h *httpHandler) handleScannerDefsFile(ctx context.Context, zipF *zip.File) // POST requests only update the offline feed. b := &storage.Blob{ - Name: scannerDefinitionBlobName, + Name: offlineScannerDefinitionBlobName, LastUpdated: timestamp.TimestampNow(), ModifiedTime: timestamp.TimestampNow(), Length: zipF.FileInfo().Size(), @@ -314,14 +311,14 @@ func (h *httpHandler) cleanupUpdaters(cleanupAge time.Duration) { } func (h *httpHandler) openOfflineBlob(ctx context.Context) (*vulDefFile, error) { - snap, err := snapshot.TakeBlobSnapshot(sac.WithAllAccess(ctx), h.blobStore, offlineScannerDefsName) + snap, err := snapshot.TakeBlobSnapshot(sac.WithAllAccess(ctx), h.blobStore, offlineScannerDefinitionBlobName) if err != nil { // If the blob does not exist, return no reader. 
if errors.Is(err, snapshot.ErrBlobNotExist) { - log.Warnf("Blob %s does not exist", offlineScannerDefsName) + log.Warnf("Blob %s does not exist", offlineScannerDefinitionBlobName) return nil, nil } - log.Warnf("Cannnot take a snapshot of Blob %q: %v", offlineScannerDefsName, err) + log.Warnf("Cannnot take a snapshot of Blob %q: %v", offlineScannerDefinitionBlobName, err) return nil, err } modTime := time.Time{} diff --git a/central/scannerdefinitions/handler/handler_test.go b/central/scannerdefinitions/handler/handler_test.go index 8edc6ca12ff23..d83eec49a4359 100644 --- a/central/scannerdefinitions/handler/handler_test.go +++ b/central/scannerdefinitions/handler/handler_test.go @@ -61,7 +61,7 @@ func (s *handlerTestSuite) SetupTest() { func (s *handlerTestSuite) TearDownSuite() { entries, err := os.ReadDir(s.tmpDir) s.NoError(err) - s.Less(len(entries), 1) + s.LessOrEqual(len(entries), 1) if len(entries) == 1 { s.True(strings.HasPrefix(entries[0].Name(), definitionsBaseDir)) } @@ -176,7 +176,7 @@ func (s *handlerTestSuite) mustWriteOffline(content string, modTime time.Time) { modifiedTime, err := types.TimestampProto(modTime) s.NoError(err) blob := &storage.Blob{ - Name: offlineScannerDefsName, + Name: offlineScannerDefinitionBlobName, ModifiedTime: modifiedTime, } s.Require().NoError(s.datastore.Upsert(s.ctx, blob, bytes.NewBuffer([]byte(content)))) From 595d12b968ace50b4dbf250014334055bd497f74 Mon Sep 17 00:00:00 2001 From: cdu Date: Tue, 16 May 2023 09:32:08 -0700 Subject: [PATCH 21/40] stage --- .../migration.go | 106 +++++++++++++ .../migration_test.go | 129 +++++++++++++++ .../schema/blobs.go | 35 +++++ .../schema/convert_blobs.go | 28 ++++ .../schema/convert_blobs_test.go | 20 +++ .../schema/gen.go | 3 + .../gorm/largeobject/large_objects.go | 147 ++++++++++++++++++ pkg/postgres/gorm/largeobject/utils_test.go | 48 ++++++ 8 files changed, 516 insertions(+) create mode 100644 migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go create mode 100644 
migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go create mode 100644 migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/blobs.go create mode 100644 migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs.go create mode 100644 migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs_test.go create mode 100644 migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/gen.go create mode 100644 pkg/postgres/gorm/largeobject/large_objects.go create mode 100644 pkg/postgres/gorm/largeobject/utils_test.go diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go new file mode 100644 index 0000000000000..535077e175b9f --- /dev/null +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go @@ -0,0 +1,106 @@ +package m179tom180 + +import ( + "context" + "database/sql" + "os" + + "github.com/pkg/errors" + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/migrator/migrations" + "github.com/stackrox/rox/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema" + "github.com/stackrox/rox/migrator/types" + "github.com/stackrox/rox/pkg/logging" + "github.com/stackrox/rox/pkg/sac" + "gorm.io/gorm" +) + +const ( + scannerDefBlobName = "/offline/scanner/scanner-defs.zip" + scannerDefPath = "/var/lib/stackrox/scannerdefinitions/scanner-defs.zip" +) + +var ( + migration = types.Migration{ + StartingSeqNum: 180, + VersionAfter: &storage.Version{SeqNum: 181}, + Run: func(databases *types.Databases) error { + err := moveToBlobs(databases.GormDB) + if err != nil { + return errors.Wrap(err, "updating policies") + } + return nil + }, + } + log = logging.LoggerForModule() + toBeMigrated = map[string]string{ + scannerDefPath: scannerDefBlobName, + } +) + +func moveToBlobs(db *gorm.DB) (err error) { + ctx := sac.WithAllAccess(context.Background()) + db = db.WithContext(ctx).Table(schema.BlobsTableName) + 
if err := db.WithContext(ctx).AutoMigrate(schema.CreateTableBlobsStmt.GormModel); err != nil { + return err + } + tx := db.Model(schema.CreateTableBlobsStmt.GormModel).Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + defer func() { + if err != nil { + tx.Rollback() + } + }() + + for p, blobName := range toBeMigrated { + f, err := os.Open(p) + if errors.Is(err, os.ErrNotExist) { + continue + } + if err != nil { + return err + } + target := &schema.Blobs{Name: blobName} + result := tx.Take(target) + if result.Error != nil { + return result.Error + } + var blob *storage.Blob + if result.RowsAffected == 0 { + // Create + // err := o.tx.QueryRow(ctx, "select lo_create($1)", oid).Scan(&oid) + var oid int + tx.Select("lo_create(0)").Find() + tx.Exec("SELECT lo_create(0)").Find(&oid) + blob = &storage.Blob{ + Name: blobName, + Oid: 0, + Length: 0, + ModifiedTime: nil, + } + } else { + // Update + existingBlob, err := schema.ConvertBlobToProto(target) + if err != nil { + return err + } + blob = &storage.Blob{ + Name: blobName, + Oid: existingBlob.Oid, + Length: 0, + ModifiedTime: nil, + } + } + blobModel, err := schema.ConvertBlobFromProto(blob) + if err != nil { + return err + } + tx.Exec("") + tx = tx.FirstOrCreate(blobModel) + + } + return tx.Commit().Error +} + +func init() { + migrations.MustRegisterMigration(migration) +} diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go new file mode 100644 index 0000000000000..5b8a68f457b39 --- /dev/null +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go @@ -0,0 +1,129 @@ +//go:build sql_integration + +package m179tom180 + +import ( + "context" + "fmt" + "testing" + + "github.com/stackrox/rox/generated/storage" + frozenSchema "github.com/stackrox/rox/migrator/migrations/frozenschema/v73" + policyPostgresStore 
"github.com/stackrox/rox/migrator/migrations/m_179_to_m_180_openshift_policy_exclusions/postgres" + pghelper "github.com/stackrox/rox/migrator/migrations/postgreshelper" + "github.com/stackrox/rox/migrator/types" + "github.com/stackrox/rox/pkg/fixtures" + "github.com/stackrox/rox/pkg/postgres/pgutils" + "github.com/stackrox/rox/pkg/sac" + "github.com/stackrox/rox/pkg/search" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type categoriesMigrationTestSuite struct { + suite.Suite + + db *pghelper.TestPostgres + policyStore policyPostgresStore.Store +} + +func TestMigration(t *testing.T) { + suite.Run(t, new(categoriesMigrationTestSuite)) +} + +func (s *categoriesMigrationTestSuite) SetupTest() { + s.db = pghelper.ForT(s.T(), true) + s.policyStore = policyPostgresStore.New(s.db.DB) + pgutils.CreateTableFromModel(context.Background(), s.db.GetGormDB(), frozenSchema.CreateTablePoliciesStmt) + +} + +func (s *categoriesMigrationTestSuite) TearDownTest() { + s.db.Teardown(s.T()) +} + +func (s *categoriesMigrationTestSuite) TestMigration() { + ctx := sac.WithAllAccess(context.Background()) + testPolicy := fixtures.GetPolicy() + testPolicy.Id = "ed8c7957-14de-40bc-aeab-d27ceeecfa7b" + testPolicy.Name = "Iptables Executed in Privileged Container" + testPolicy.Description = "Alert on privileged pods that execute iptables" + testPolicy.PolicySections = []*storage.PolicySection{ + { + PolicyGroups: []*storage.PolicyGroup{ + { + FieldName: "Privileged Container", + Values: []*storage.PolicyValue{ + { + Value: "true", + }, + }, + }, + { + FieldName: "Process Name", + Values: []*storage.PolicyValue{ + { + Value: "iptables", + }, + }, + }, + { + FieldName: "Process UID", + Values: []*storage.PolicyValue{ + { + Value: "0", + }, + }, + }, + }, + }, + } + require.NoError(s.T(), s.policyStore.Upsert(ctx, testPolicy)) + // insert other policies in db for migration to run successfully + policies := []string{ + "fb8f8732-c31d-496b-8fb1-d5abe6056e27", + 
"880fd131-46f0-43d2-82c9-547f5aa7e043", + "47cb9e0a-879a-417b-9a8f-de644d7c8a77", + "6226d4ad-7619-4a0b-a160-46373cfcee66", + "436811e7-892f-4da6-a0f5-8cc459f1b954", + "742e0361-bddd-4a2d-8758-f2af6197f61d", + "16c95922-08c4-41b6-a721-dc4b2a806632", + "fe9de18b-86db-44d5-a7c4-74173ccffe2e", + "dce17697-1b72-49d2-b18a-05d893cd9368", + "f4996314-c3d7-4553-803b-b24ce7febe48", + "a9b9ecf7-9707-4e32-8b62-d03018ed454f", + "32d770b9-c6ba-4398-b48a-0c3e807644ed", + "f95ff08d-130a-465a-a27e-32ed1fb05555", + } + + policyName := "policy description %d" + for i := 0; i < len(policies); i++ { + require.NoError(s.T(), s.policyStore.Upsert(ctx, &storage.Policy{ + Id: policies[i], + Name: fmt.Sprintf(policyName, i), + })) + } + dbs := &types.Databases{ + PostgresDB: s.db.DB, + GormDB: s.db.GetGormDB(), + } + + q := search.NewQueryBuilder().AddExactMatches(search.PolicyID, testPolicy.GetId()).ProtoQuery() + policyPremigration, err := s.policyStore.GetByQuery(ctx, q) + s.NoError(err) + s.Empty(policyPremigration[0].Exclusions) + s.NoError(migration.Run(dbs)) + expectedExclusions := []string{"Don't alert on ovnkube-node deployment in openshift-ovn-kubernetes Namespace", + "Don't alert on haproxy-* deployment in openshift-vsphere-infra namespace", + "Don't alert on keepalived-* deployment in openshift-vsphere-infra namespace", + "Don't alert on coredns-* deployment in openshift-vsphere-infra namespace"} + query := search.NewQueryBuilder().AddExactMatches(search.PolicyID, testPolicy.GetId()).ProtoQuery() + policy, err := s.policyStore.GetByQuery(ctx, query) + s.NoError(err) + var actualExclusions []string + for _, excl := range policy[0].Exclusions { + actualExclusions = append(actualExclusions, excl.Name) + } + s.ElementsMatch(actualExclusions, expectedExclusions, "exclusion do not match after migration") + +} diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/blobs.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/blobs.go new file mode 100644 
index 0000000000000..c6401dcf1b832 --- /dev/null +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/blobs.go @@ -0,0 +1,35 @@ +// Code generated by pg-bindings generator. DO NOT EDIT. + +package schema + +import ( + "reflect" + + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/postgres" + "github.com/stackrox/rox/pkg/postgres/walker" +) + +var ( + // CreateTableBlobsStmt holds the create statement for table `blobs`. + CreateTableBlobsStmt = &postgres.CreateStmts{ + GormModel: (*Blobs)(nil), + Children: []*postgres.CreateStmts{}, + } + + // BlobsSchema is the go schema for table `blobs`. + BlobsSchema = func() *walker.Schema { + schema := walker.Walk(reflect.TypeOf((*storage.Blob)(nil)), "blobs") + return schema + }() +) + +const ( + BlobsTableName = "blobs" +) + +// Blobs holds the Gorm model for Postgres table `blobs`. +type Blobs struct { + Name string `gorm:"column:name;type:varchar;primaryKey"` + Serialized []byte `gorm:"column:serialized;type:bytea"` +} diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs.go new file mode 100644 index 0000000000000..25223842c6131 --- /dev/null +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs.go @@ -0,0 +1,28 @@ +// Code generated by pg-bindings generator. DO NOT EDIT. 
+package schema + +import ( + "github.com/stackrox/rox/generated/storage" +) + +// ConvertBlobFromProto converts a `*storage.Blob` to Gorm model +func ConvertBlobFromProto(obj *storage.Blob) (*Blobs, error) { + serialized, err := obj.Marshal() + if err != nil { + return nil, err + } + model := &Blobs{ + Name: obj.GetName(), + Serialized: serialized, + } + return model, nil +} + +// ConvertBlobToProto converts Gorm model `Blobs` to its protobuf type object +func ConvertBlobToProto(m *Blobs) (*storage.Blob, error) { + var msg storage.Blob + if err := msg.Unmarshal(m.Serialized); err != nil { + return nil, err + } + return &msg, nil +} diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs_test.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs_test.go new file mode 100644 index 0000000000000..a2f2300ee3c51 --- /dev/null +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs_test.go @@ -0,0 +1,20 @@ +// Code generated by pg-bindings generator. DO NOT EDIT. 
+package schema + +import ( + "testing" + + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/testutils" + "github.com/stretchr/testify/assert" +) + +func TestBlobSerialization(t *testing.T) { + obj := &storage.Blob{} + assert.NoError(t, testutils.FullInit(obj, testutils.UniqueInitializer(), testutils.JSONFieldsFilter)) + m, err := ConvertBlobFromProto(obj) + assert.NoError(t, err) + conv, err := ConvertBlobToProto(m) + assert.NoError(t, err) + assert.Equal(t, obj, conv) +} diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/gen.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/gen.go new file mode 100644 index 0000000000000..8265e535a968e --- /dev/null +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/gen.go @@ -0,0 +1,3 @@ +package schema + +//go:generate pg-schema-migration-helper --type=storage.Blob diff --git a/pkg/postgres/gorm/largeobject/large_objects.go b/pkg/postgres/gorm/largeobject/large_objects.go new file mode 100644 index 0000000000000..c5a814cabd124 --- /dev/null +++ b/pkg/postgres/gorm/largeobject/large_objects.go @@ -0,0 +1,147 @@ +package largeobject + +import ( + "errors" + "io" + + "gorm.io/gorm" +) + +// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it +// was created. +// +// For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html +type LargeObjects struct { + *gorm.DB +} + +type Mode int32 + +const ( + ModeWrite Mode = 0x20000 + ModeRead Mode = 0x40000 +) + +// Create creates a new large object with an unused OID assigned +func (o *LargeObjects) Create() (uint32, error) { + result := o.Raw("SELECT lo_create(?)", 0) + if result.Error != nil { + return 0, result.Error + } + var oid uint32 + result.Scan(&oid) + return oid, nil +} + +// Open opens an existing large object with the given mode. ctx will also be used for all operations on the opened large +// object. 
+func (o *LargeObjects) Open(oid uint32, mode Mode) (*LargeObject, error) { + var fd int32 + result := o.Raw("select lo_open(?, ?)", oid, mode).Scan(&fd) + if result.Error != nil { + return nil, result.Error + } + return &LargeObject{fd: fd, tx: o.DB}, nil +} + +// Unlink removes a large object from the database. +func (o *LargeObjects) Unlink(oid uint32) error { + var count int32 + result := o.Raw("select lo_unlink(?)", oid).Scan(&count) + if result.Error != nil { + return result.Error + } + + if count != 1 { + return errors.New("failed to remove large object") + } + + return nil +} + +func (o *LargeObjects) Upsert(oid uint32, r io.Reader) error { + obj, err := o.Open(oid, ModeWrite) + if err != nil { + return err + } + _, err = io.Copy(obj, r) + + return err +} + +func (o *LargeObjects) Get(oid uint32, w io.Writer) error { + obj, err := o.Open(oid, ModeRead) + if err != nil { + return err + } + _, err = io.Copy(w, obj) + return err +} + +// A LargeObject is a large object stored on the server. It is only valid within the transaction that it was initialized +// in. It uses the context it was initialized with for all operations. It implements these interfaces: +// +// io.Writer +// io.Reader +// io.Seeker +// io.Closer +type LargeObject struct { + tx *gorm.DB + fd int32 +} + +// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. +func (o *LargeObject) Write(p []byte) (int, error) { + var n int + result := o.tx.Raw("select lowrite($1, $2)", o.fd, p).Scan(&n) + if result.Error != nil { + return n, result.Error + } + + if n < 0 { + return 0, errors.New("failed to write to large object") + } + + return n, nil +} + +// Read reads up to len(p) bytes into p returning the number of bytes read. 
+func (o *LargeObject) Read(p []byte) (n int, err error) { + var res []byte + result := o.tx.Raw("select loread($1, $2)", o.fd, len(p)).Scan(&res) + copy(p, res) + if result.Error != nil { + return len(res), result.Error + } + + if len(res) < len(p) { + err = io.EOF + } + return len(res), err +} + +// Seek moves the current location pointer to the new location specified by offset. +func (o *LargeObject) Seek(offset int64, whence int) (int64, error) { + var n int64 + result := o.tx.Raw("select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n) + return n, result.Error +} + +// Tell returns the current read or write location of the large object descriptor. +func (o *LargeObject) Tell() (int64, error) { + var n int64 + result := o.tx.Raw("select lo_tell64($1)", o.fd).Scan(&n) + return n, result.Error +} + +// Truncate the large object to size. +func (o *LargeObject) Truncate(size int64) (err error) { + result := o.tx.Raw("select lo_truncate64(?, ?)", o.fd, size) + return result.Error +} + +// Close the large object descriptor. 
+func (o *LargeObject) Close() error { + result := o.tx.Raw("select lo_close(?)", o.fd) + return result.Error +} diff --git a/pkg/postgres/gorm/largeobject/utils_test.go b/pkg/postgres/gorm/largeobject/utils_test.go new file mode 100644 index 0000000000000..d736860e1a822 --- /dev/null +++ b/pkg/postgres/gorm/largeobject/utils_test.go @@ -0,0 +1,48 @@ +package largeobject + +import ( + "bytes" + "context" + "crypto/rand" + "testing" + + pghelper "github.com/stackrox/rox/migrator/migrations/postgreshelper" + "github.com/stretchr/testify/suite" +) + +type gormUtilsTestSuite struct { + suite.Suite + + db *pghelper.TestPostgres + ctx context.Context +} + +func TestMigration(t *testing.T) { + suite.Run(t, new(gormUtilsTestSuite)) +} + +func (s *gormUtilsTestSuite) SetupTest() { + s.db = pghelper.ForT(s.T(), true) + s.ctx = context.Background() + +} + +func (s *gormUtilsTestSuite) TearDownTest() { + s.db.Teardown(s.T()) +} + +func (s *gormUtilsTestSuite) TestMigration() { + randomData := make([]byte, 10000) + _, err := rand.Read(randomData) + s.NoError(err) + + reader := bytes.NewBuffer(randomData) + gormDB := s.db.GetGormDB() + tx := gormDB.Begin() + los := LargeObjects{tx} + oid, err := los.Create() + s.NoError(err) + err = los.Upsert(oid, reader) + s.NoError(err) + s.NoError(tx.Commit().Error) +} From 6b52d7cc8899970e852edc1d4597e5b551363f8e Mon Sep 17 00:00:00 2001 From: cdu Date: Tue, 16 May 2023 12:18:53 -0700 Subject: [PATCH 22/40] stage --- .../gorm/largeobject/large_objects.go | 81 +++-- pkg/postgres/gorm/largeobject/utils_test.go | 317 +++++++++++++++++- 2 files changed, 364 insertions(+), 34 deletions(-) diff --git a/pkg/postgres/gorm/largeobject/large_objects.go b/pkg/postgres/gorm/largeobject/large_objects.go index c5a814cabd124..1c2eec2aba985 100644 --- a/pkg/postgres/gorm/largeobject/large_objects.go +++ b/pkg/postgres/gorm/largeobject/large_objects.go @@ -12,7 +12,7 @@ import ( // // For more details see: 
http://www.postgresql.org/docs/current/static/largeobjects.html type LargeObjects struct { - *gorm.DB + tx *gorm.DB } type Mode int32 @@ -24,38 +24,36 @@ const ( // Create creates a new large object with an unused OID assigned func (o *LargeObjects) Create() (uint32, error) { - result := o.Raw("SELECT lo_create(?)", 0) - if result.Error != nil { - return 0, result.Error + o.tx = o.tx.Raw("SELECT lo_create(?)", 0) + if err := o.tx.Error; err != nil { + return 0, err } var oid uint32 - result.Scan(&oid) - return oid, nil + o.tx = o.tx.Scan(&oid) + return oid, o.tx.Error } // Open opens an existing large object with the given mode. ctx will also be used for all operations on the opened large // object. func (o *LargeObjects) Open(oid uint32, mode Mode) (*LargeObject, error) { var fd int32 - result := o.Raw("select lo_open(?, ?)", oid, mode).Scan(&fd) - if result.Error != nil { - return nil, result.Error + o.tx = o.tx.Raw("select lo_open(?, ?)", oid, mode).Scan(&fd) + if err := o.tx.Error; err != nil { + return nil, err } - return &LargeObject{fd: fd, tx: o.DB}, nil + return &LargeObject{fd: fd, tx: o.tx}, nil } // Unlink removes a large object from the database. 
func (o *LargeObjects) Unlink(oid uint32) error { var count int32 - result := o.Raw("select lo_unlink(?)", oid).Scan(&count) - if result.Error != nil { - return result.Error + o.tx = o.tx.Raw("select lo_unlink(?)", oid).Scan(&count) + if err := o.tx.Error; err != nil { + return err } - if count != 1 { return errors.New("failed to remove large object") } - return nil } @@ -63,10 +61,17 @@ func (o *LargeObjects) Upsert(oid uint32, r io.Reader) error { obj, err := o.Open(oid, ModeWrite) if err != nil { return err - } + } /* + err = obj.Truncate(1) + if err != nil { + return err + } + if _, err = obj.Seek(0, io.SeekStart); err != nil { + return err + }*/ _, err = io.Copy(obj, r) - return err + return obj.Close() } func (o *LargeObjects) Get(oid uint32, w io.Writer) error { @@ -93,9 +98,15 @@ type LargeObject struct { // Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. func (o *LargeObject) Write(p []byte) (int, error) { var n int - result := o.tx.Raw("select lowrite($1, $2)", o.fd, p).Scan(&n) - if result.Error != nil { - return n, result.Error + o.tx = o.tx.Raw("select lowrite($1, $2)", o.fd, p) + if err := o.tx.Error; err != nil { + return n, err + } + if err := o.tx.Scan(&n).Error; err != nil { + return n, err + } + if err := o.tx.Error; err != nil { + return n, err } if n < 0 { @@ -107,11 +118,17 @@ func (o *LargeObject) Write(p []byte) (int, error) { // Read reads up to len(p) bytes into p returning the number of bytes read. 
func (o *LargeObject) Read(p []byte) (n int, err error) { - var res []byte - result := o.tx.Raw("select loread($1, $2)", o.fd, len(p)).Scan(&res) + var res []byte = make([]byte, 0, len(p)) + o.tx = o.tx.Raw("select loread($1, $2)", o.fd, len(p)) + if err = o.tx.Error; err != nil { + return 0, err + } + if err = o.tx.Row().Scan(&res); err != nil { + return 0, err + } copy(p, res) - if result.Error != nil { - return len(res), result.Error + if err = o.tx.Error; err != nil { + return len(res), err } if len(res) < len(p) { @@ -123,25 +140,25 @@ func (o *LargeObject) Read(p []byte) (n int, err error) { // Seek moves the current location pointer to the new location specified by offset. func (o *LargeObject) Seek(offset int64, whence int) (int64, error) { var n int64 - result := o.tx.Raw("select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n) - return n, result.Error + o.tx = o.tx.Raw("select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n) + return n, o.tx.Error } // Tell returns the current read or write location of the large object descriptor. func (o *LargeObject) Tell() (int64, error) { var n int64 - result := o.tx.Raw("select lo_tell64($1)", o.fd).Scan(&n) - return n, result.Error + o.tx = o.tx.Raw("select lo_tell64($1)", o.fd).Scan(&n) + return n, o.tx.Error } // Truncate the large object to size. func (o *LargeObject) Truncate(size int64) (err error) { - result := o.tx.Raw("select lo_truncate64(?, ?)", o.fd, size) - return result.Error + o.tx = o.tx.Raw("select lo_truncate64(?, ?)", o.fd, size) + return o.tx.Error } // Close the large object descriptor. 
func (o *LargeObject) Close() error { - result := o.tx.Raw("select lo_close(?)", o.fd) - return result.Error + o.tx = o.tx.Raw("select lo_close(?)", o.fd) + return o.tx.Error } diff --git a/pkg/postgres/gorm/largeobject/utils_test.go b/pkg/postgres/gorm/largeobject/utils_test.go index d736860e1a822..c3b1f62d72cc6 100644 --- a/pkg/postgres/gorm/largeobject/utils_test.go +++ b/pkg/postgres/gorm/largeobject/utils_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/rand" + "database/sql" "testing" pghelper "github.com/stackrox/rox/migrator/migrations/postgreshelper" @@ -36,13 +37,325 @@ func (s *gormUtilsTestSuite) TestMigration() { _, err := rand.Read(randomData) s.NoError(err) - reader := bytes.NewBuffer(randomData) gormDB := s.db.GetGormDB() - tx := gormDB.Begin() + + reader := bytes.NewBuffer(randomData) + tx := gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) los := LargeObjects{tx} oid, err := los.Create() s.NoError(err) err = los.Upsert(oid, reader) s.NoError(err) s.NoError(tx.Commit().Error) + + tx = gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + los = LargeObjects{tx} + writer := bytes.NewBuffer([]byte{}) + s.Require().NoError(los.Get(oid, writer)) + + s.Require().Equal(randomData, writer.Bytes()) + reader = bytes.NewBuffer([]byte("hi")) + tx = gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + los = LargeObjects{tx} + err = los.Upsert(oid, reader) + s.NoError(err) + s.NoError(tx.Commit().Error) + + tx = gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + los = LargeObjects{tx} + writer = bytes.NewBuffer([]byte{}) + s.Require().NoError(los.Get(oid, writer)) + s.Require().Equal([]byte("hi"), writer.Bytes()) +} + +/* +func TestLargeObjects(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + + skipCockroachDB(t, 
conn, "Server does support large objects") + + tx, err := conn.Begin(ctx) + if err != nil { + t.Fatal(err) + } + + testLargeObjects(t, ctx, tx) +} + +func TestLargeObjectsPreferSimpleProtocol(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + + config.PreferSimpleProtocol = true + + conn, err := pgx.ConnectConfig(ctx, config) + if err != nil { + t.Fatal(err) + } + + skipCockroachDB(t, conn, "Server does support large objects") + + tx, err := conn.Begin(ctx) + if err != nil { + t.Fatal(err) + } + + testLargeObjects(t, ctx, tx) +} + +func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) { + lo := tx.LargeObjects() + + id, err := lo.Create(ctx, 0) + if err != nil { + t.Fatal(err) + } + + obj, err := lo.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite) + if err != nil { + t.Fatal(err) + } + + n, err := obj.Write([]byte("testing")) + if err != nil { + t.Fatal(err) + } + if n != 7 { + t.Errorf("Expected n to be 7, got %d", n) + } + + pos, err := obj.Seek(1, 0) + if err != nil { + t.Fatal(err) + } + if pos != 1 { + t.Errorf("Expected pos to be 1, got %d", pos) + } + + res := make([]byte, 6) + n, err = obj.Read(res) + if err != nil { + t.Fatal(err) + } + if string(res) != "esting" { + t.Errorf(`Expected res to be "esting", got %q`, res) + } + if n != 6 { + t.Errorf("Expected n to be 6, got %d", n) + } + + n, err = obj.Read(res) + if err != io.EOF { + t.Error("Expected io.EOF, go nil") + } + if n != 0 { + t.Errorf("Expected n to be 0, got %d", n) + } + + pos, err = obj.Tell() + if err != nil { + t.Fatal(err) + } + if pos != 7 { + t.Errorf("Expected pos to be 7, got %d", pos) + } + + err = obj.Truncate(1) + if err != nil { + t.Fatal(err) + } + + pos, err = obj.Seek(-1, 2) + if err != nil { + t.Fatal(err) + } + if pos != 0 { + t.Errorf("Expected pos to be 0, got %d", pos) 
+ } + + res = make([]byte, 2) + n, err = obj.Read(res) + if err != io.EOF { + t.Errorf("Expected err to be io.EOF, got %v", err) + } + if n != 1 { + t.Errorf("Expected n to be 1, got %d", n) + } + if res[0] != 't' { + t.Errorf("Expected res[0] to be 't', got %v", res[0]) + } + + err = obj.Close() + if err != nil { + t.Fatal(err) + } + + err = lo.Unlink(ctx, id) + if err != nil { + t.Fatal(err) + } + + _, err = lo.Open(ctx, id, pgx.LargeObjectModeRead) + if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" { + t.Errorf("Expected undefined_object error (42704), got %#v", err) + } +} + +func TestLargeObjectsMultipleTransactions(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + + skipCockroachDB(t, conn, "Server does support large objects") + + tx, err := conn.Begin(ctx) + if err != nil { + t.Fatal(err) + } + + lo := tx.LargeObjects() + + id, err := lo.Create(ctx, 0) + if err != nil { + t.Fatal(err) + } + + obj, err := lo.Open(ctx, id, pgx.LargeObjectModeWrite) + if err != nil { + t.Fatal(err) + } + + n, err := obj.Write([]byte("testing")) + if err != nil { + t.Fatal(err) + } + if n != 7 { + t.Errorf("Expected n to be 7, got %d", n) + } + + // Commit the first transaction + err = tx.Commit(ctx) + if err != nil { + t.Fatal(err) + } + + // IMPORTANT: Use the same connection for another query + query := `select n from generate_series(1,10) n` + rows, err := conn.Query(ctx, query) + if err != nil { + t.Fatal(err) + } + rows.Close() + + // Start a new transaction + tx2, err := conn.Begin(ctx) + if err != nil { + t.Fatal(err) + } + + lo2 := tx2.LargeObjects() + + // Reopen the large object in the new transaction + obj2, err := lo2.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite) + if err != nil { + t.Fatal(err) + } + + pos, err := obj2.Seek(1, 0) + if err != nil { + 
t.Fatal(err) + } + if pos != 1 { + t.Errorf("Expected pos to be 1, got %d", pos) + } + + res := make([]byte, 6) + n, err = obj2.Read(res) + if err != nil { + t.Fatal(err) + } + if string(res) != "esting" { + t.Errorf(`Expected res to be "esting", got %q`, res) + } + if n != 6 { + t.Errorf("Expected n to be 6, got %d", n) + } + + n, err = obj2.Read(res) + if err != io.EOF { + t.Error("Expected io.EOF, go nil") + } + if n != 0 { + t.Errorf("Expected n to be 0, got %d", n) + } + + pos, err = obj2.Tell() + if err != nil { + t.Fatal(err) + } + if pos != 7 { + t.Errorf("Expected pos to be 7, got %d", pos) + } + + err = obj2.Truncate(1) + if err != nil { + t.Fatal(err) + } + + pos, err = obj2.Seek(-1, 2) + if err != nil { + t.Fatal(err) + } + if pos != 0 { + t.Errorf("Expected pos to be 0, got %d", pos) + } + + res = make([]byte, 2) + n, err = obj2.Read(res) + if err != io.EOF { + t.Errorf("Expected err to be io.EOF, got %v", err) + } + if n != 1 { + t.Errorf("Expected n to be 1, got %d", n) + } + if res[0] != 't' { + t.Errorf("Expected res[0] to be 't', got %v", res[0]) + } + + err = obj2.Close() + if err != nil { + t.Fatal(err) + } + + err = lo2.Unlink(ctx, id) + if err != nil { + t.Fatal(err) + } + + _, err = lo2.Open(ctx, id, pgx.LargeObjectModeRead) + if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" { + t.Errorf("Expected undefined_object error (42704), got %#v", err) + } } +*/ From 375ad885e58110a2fc42a5e1d4601328cd6d39c0 Mon Sep 17 00:00:00 2001 From: cdu Date: Tue, 16 May 2023 16:46:13 -0700 Subject: [PATCH 23/40] stage --- .../gorm/largeobject/large_objects.go | 61 +++- pkg/postgres/gorm/largeobject/utils_test.go | 342 +++++------------- 2 files changed, 142 insertions(+), 261 deletions(-) diff --git a/pkg/postgres/gorm/largeobject/large_objects.go b/pkg/postgres/gorm/largeobject/large_objects.go index 1c2eec2aba985..a79ea50b3698a 100644 --- a/pkg/postgres/gorm/largeobject/large_objects.go +++ b/pkg/postgres/gorm/largeobject/large_objects.go @@ -58,18 
+58,23 @@ func (o *LargeObjects) Unlink(oid uint32) error { } func (o *LargeObjects) Upsert(oid uint32, r io.Reader) error { - obj, err := o.Open(oid, ModeWrite) + obj, err := o.Open(oid, ModeWrite|ModeRead) if err != nil { return err - } /* - err = obj.Truncate(1) - if err != nil { - return err - } - if _, err = obj.Seek(0, io.SeekStart); err != nil { - return err - }*/ + } + _, err = obj.Truncate(0) + if err != nil { + return err + } + obj.Close() + obj, err = o.Open(oid, ModeWrite) + if err != nil { + return err + } _, err = io.Copy(obj, r) + if err != nil { + return err + } return obj.Close() } @@ -98,11 +103,11 @@ type LargeObject struct { // Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. func (o *LargeObject) Write(p []byte) (int, error) { var n int - o.tx = o.tx.Raw("select lowrite($1, $2)", o.fd, p) + o.tx = o.tx.Raw("select lowrite(?, ?)", o.fd, p) if err := o.tx.Error; err != nil { return n, err } - if err := o.tx.Scan(&n).Error; err != nil { + if err := o.tx.Row().Scan(&n); err != nil { return n, err } if err := o.tx.Error; err != nil { @@ -119,7 +124,7 @@ func (o *LargeObject) Write(p []byte) (int, error) { // Read reads up to len(p) bytes into p returning the number of bytes read. func (o *LargeObject) Read(p []byte) (n int, err error) { var res []byte = make([]byte, 0, len(p)) - o.tx = o.tx.Raw("select loread($1, $2)", o.fd, len(p)) + o.tx = o.tx.Raw("select loread(?, ?)", o.fd, len(p)) if err = o.tx.Error; err != nil { return 0, err } @@ -140,7 +145,15 @@ func (o *LargeObject) Read(p []byte) (n int, err error) { // Seek moves the current location pointer to the new location specified by offset. 
func (o *LargeObject) Seek(offset int64, whence int) (int64, error) { var n int64 - o.tx = o.tx.Raw("select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n) + result := o.tx.Raw("select lo_lseek64(?, ?, ?)", o.fd, offset, whence) + if result.Error != nil { + return 0, o.tx.Error + } + row := o.tx.Row() + row.Scan(&n) + if result.Error != nil { + return 0, result.Error + } return n, o.tx.Error } @@ -152,13 +165,27 @@ func (o *LargeObject) Tell() (int64, error) { } // Truncate the large object to size. -func (o *LargeObject) Truncate(size int64) (err error) { - o.tx = o.tx.Raw("select lo_truncate64(?, ?)", o.fd, size) - return o.tx.Error +func (o *LargeObject) Truncate(size int64) (n int, err error) { + result := o.tx.Raw("select lo_truncate64(?, ?)", o.fd, size).Scan(&n) + return n, result.Error } // Close the large object descriptor. func (o *LargeObject) Close() error { - o.tx = o.tx.Raw("select lo_close(?)", o.fd) + var n int + o.tx = o.tx.Raw("select lo_close(?)", o.fd).Scan(&n) return o.tx.Error } + +/* +{ var n int64 +o.tx = o.tx.Raw("select lo_tell64(?)", o.fd) +if o.tx.Error != nil { +return n, o.tx.Error +} +row := o.tx.Row() +if row.Err() != nil { +return n, row.Err() +} +err := row.Scan(&n) +return n, err}*/ diff --git a/pkg/postgres/gorm/largeobject/utils_test.go b/pkg/postgres/gorm/largeobject/utils_test.go index c3b1f62d72cc6..ac3b39d6f683e 100644 --- a/pkg/postgres/gorm/largeobject/utils_test.go +++ b/pkg/postgres/gorm/largeobject/utils_test.go @@ -5,35 +5,36 @@ import ( "context" "crypto/rand" "database/sql" + "io" "testing" pghelper "github.com/stackrox/rox/migrator/migrations/postgreshelper" "github.com/stretchr/testify/suite" + "gorm.io/gorm" ) -type gormUtilsTestSuite struct { +type GormUtilsTestSuite struct { suite.Suite db *pghelper.TestPostgres ctx context.Context } -func TestMigration(t *testing.T) { - suite.Run(t, new(gormUtilsTestSuite)) +func TestLargeObjects(t *testing.T) { + suite.Run(t, new(GormUtilsTestSuite)) } -func (s 
*gormUtilsTestSuite) SetupTest() { +func (s *GormUtilsTestSuite) SetupTest() { s.db = pghelper.ForT(s.T(), true) s.ctx = context.Background() - } -func (s *gormUtilsTestSuite) TearDownTest() { +func (s *GormUtilsTestSuite) TearDownTest() { s.db.Teardown(s.T()) } -func (s *gormUtilsTestSuite) TestMigration() { - randomData := make([]byte, 10000) +func (s *GormUtilsTestSuite) TestUpsertGet() { + randomData := make([]byte, 100) _, err := rand.Read(randomData) s.NoError(err) @@ -43,10 +44,10 @@ func (s *gormUtilsTestSuite) TestMigration() { tx := gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) los := LargeObjects{tx} oid, err := los.Create() - s.NoError(err) + s.Require().NoError(err) err = los.Upsert(oid, reader) - s.NoError(err) - s.NoError(tx.Commit().Error) + s.Require().NoError(err) + s.Require().NoError(tx.Commit().Error) tx = gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) los = LargeObjects{tx} @@ -58,304 +59,157 @@ func (s *gormUtilsTestSuite) TestMigration() { tx = gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) los = LargeObjects{tx} err = los.Upsert(oid, reader) - s.NoError(err) - s.NoError(tx.Commit().Error) + s.Require().NoError(err) + s.Require().NoError(tx.Commit().Error) tx = gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) los = LargeObjects{tx} writer = bytes.NewBuffer([]byte{}) + writer.Reset() s.Require().NoError(los.Get(oid, writer)) s.Require().Equal([]byte("hi"), writer.Bytes()) } -/* -func TestLargeObjects(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatal(err) - } - - skipCockroachDB(t, conn, "Server does support large objects") - - tx, err := conn.Begin(ctx) - if err != nil { - t.Fatal(err) - } - - testLargeObjects(t, ctx, tx) -} - -func TestLargeObjectsPreferSimpleProtocol(t *testing.T) { - t.Parallel() +func (s 
*GormUtilsTestSuite) TestLargeObject() { + ctx := context.Background() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + gormDB := s.db.GetGormDB().WithContext(ctx) - config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatal(err) - } + tx := gormDB.Begin() + s.Require().NoError(tx.Error) - config.PreferSimpleProtocol = true - - conn, err := pgx.ConnectConfig(ctx, config) - if err != nil { - t.Fatal(err) - } - - skipCockroachDB(t, conn, "Server does support large objects") - - tx, err := conn.Begin(ctx) - if err != nil { - t.Fatal(err) - } - - testLargeObjects(t, ctx, tx) + s.testLargeObject(tx) } -func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) { - lo := tx.LargeObjects() +func (s *GormUtilsTestSuite) testLargeObject(tx *gorm.DB) { + los := &LargeObjects{tx} - id, err := lo.Create(ctx, 0) - if err != nil { - t.Fatal(err) - } + id, err := los.Create() + s.Require().NoError(err) - obj, err := lo.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite) - if err != nil { - t.Fatal(err) - } + obj, err := los.Open(id, ModeWrite|ModeRead) + s.Require().NoError(err) n, err := obj.Write([]byte("testing")) - if err != nil { - t.Fatal(err) - } - if n != 7 { - t.Errorf("Expected n to be 7, got %d", n) - } + s.Require().NoError(err) + s.Require().Equal(7, n, "Expected n to be 7, got %d", n) pos, err := obj.Seek(1, 0) - if err != nil { - t.Fatal(err) - } - if pos != 1 { - t.Errorf("Expected pos to be 1, got %d", pos) - } + s.Require().NoError(err) + s.Require().Equal(int64(1), pos) res := make([]byte, 6) n, err = obj.Read(res) - if err != nil { - t.Fatal(err) - } - if string(res) != "esting" { - t.Errorf(`Expected res to be "esting", got %q`, res) - } - if n != 6 { - t.Errorf("Expected n to be 6, got %d", n) - } + s.Require().NoError(err) + s.Require().Equal("esting", string(res)) + s.Require().Equal(6, n) n, err = obj.Read(res) - if err != io.EOF { - t.Error("Expected io.EOF, 
go nil") - } - if n != 0 { - t.Errorf("Expected n to be 0, got %d", n) - } + s.Require().Equal(err, io.EOF) + s.Require().Zero(n) pos, err = obj.Tell() - if err != nil { - t.Fatal(err) - } - if pos != 7 { - t.Errorf("Expected pos to be 7, got %d", pos) - } - - err = obj.Truncate(1) - if err != nil { - t.Fatal(err) - } + s.Require().NoError(err) + s.Require().EqualValues(7, pos) + + n, err = obj.Truncate(1) + s.Require().NoError(err) pos, err = obj.Seek(-1, 2) - if err != nil { - t.Fatal(err) - } - if pos != 0 { - t.Errorf("Expected pos to be 0, got %d", pos) - } + s.Require().NoError(err) + s.Require().Zero(pos) res = make([]byte, 2) n, err = obj.Read(res) - if err != io.EOF { - t.Errorf("Expected err to be io.EOF, got %v", err) - } - if n != 1 { - t.Errorf("Expected n to be 1, got %d", n) - } - if res[0] != 't' { - t.Errorf("Expected res[0] to be 't', got %v", res[0]) - } + s.Require().Equal(io.EOF, err) + s.Require().Equal(1, n) + s.Require().EqualValues('t', res[0]) err = obj.Close() - if err != nil { - t.Fatal(err) - } - - err = lo.Unlink(ctx, id) - if err != nil { - t.Fatal(err) - } - - _, err = lo.Open(ctx, id, pgx.LargeObjectModeRead) - if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" { - t.Errorf("Expected undefined_object error (42704), got %#v", err) - } -} - -func TestLargeObjectsMultipleTransactions(t *testing.T) { - t.Parallel() + s.Require().NoError(err) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + err = los.Unlink(id) + s.Require().NoError(err) - conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatal(err) - } - - skipCockroachDB(t, conn, "Server does support large objects") + _, err = los.Open(id, ModeRead) + s.Require().Contains(err.Error(), "does not exist (SQLSTATE 42704)") +} - tx, err := conn.Begin(ctx) - if err != nil { - t.Fatal(err) - } +func (s *GormUtilsTestSuite) TestLargeObjectsMultipleTransactions() { + ctx := context.Background() - lo := 
tx.LargeObjects() + gormDB := s.db.GetGormDB().WithContext(ctx) - id, err := lo.Create(ctx, 0) - if err != nil { - t.Fatal(err) - } + tx := gormDB.Begin() + // tx := gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + s.Require().NoError(tx.Error) + los := &LargeObjects{tx} - obj, err := lo.Open(ctx, id, pgx.LargeObjectModeWrite) - if err != nil { - t.Fatal(err) - } + id, err := los.Create() + s.Require().NoError(err) + obj, err := los.Open(id, ModeWrite|ModeRead) + s.Require().NoError(err) n, err := obj.Write([]byte("testing")) - if err != nil { - t.Fatal(err) - } - if n != 7 { - t.Errorf("Expected n to be 7, got %d", n) - } + s.Require().NoError(err) + s.Require().Equal(7, n, "Expected n to be 7, got %d", n) // Commit the first transaction - err = tx.Commit(ctx) - if err != nil { - t.Fatal(err) - } + s.Require().NoError(tx.Commit().Error) // IMPORTANT: Use the same connection for another query query := `select n from generate_series(1,10) n` - rows, err := conn.Query(ctx, query) - if err != nil { - t.Fatal(err) - } + rows, err := gormDB.Raw(query).Rows() + s.Require().NoError(err) rows.Close() // Start a new transaction - tx2, err := conn.Begin(ctx) - if err != nil { - t.Fatal(err) - } - - lo2 := tx2.LargeObjects() + tx2 := gormDB.Begin() + // tx := gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + s.Require().NoError(tx.Error) + los2 := &LargeObjects{tx2} // Reopen the large object in the new transaction - obj2, err := lo2.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite) - if err != nil { - t.Fatal(err) - } + obj2, err := los2.Open(id, ModeWrite|ModeRead) + s.Require().NoError(err) pos, err := obj2.Seek(1, 0) - if err != nil { - t.Fatal(err) - } - if pos != 1 { - t.Errorf("Expected pos to be 1, got %d", pos) - } + s.Require().NoError(err) + s.Require().EqualValues(1, pos) res := make([]byte, 6) n, err = obj2.Read(res) - if err != nil { - t.Fatal(err) - } - if string(res) != "esting" { - t.Errorf(`Expected res to be 
"esting", got %q`, res) - } - if n != 6 { - t.Errorf("Expected n to be 6, got %d", n) - } + s.Require().NoError(err) + s.Require().Equal("esting", string(res)) + s.Require().Equal(6, n) n, err = obj2.Read(res) - if err != io.EOF { - t.Error("Expected io.EOF, go nil") - } - if n != 0 { - t.Errorf("Expected n to be 0, got %d", n) - } + s.Require().Equal(err, io.EOF) + s.Require().Zero(n) pos, err = obj2.Tell() - if err != nil { - t.Fatal(err) - } - if pos != 7 { - t.Errorf("Expected pos to be 7, got %d", pos) - } - - err = obj2.Truncate(1) - if err != nil { - t.Fatal(err) - } + s.Require().NoError(err) + s.Require().EqualValues(7, pos) + + n, err = obj2.Truncate(1) + s.Require().NoError(err) pos, err = obj2.Seek(-1, 2) - if err != nil { - t.Fatal(err) - } - if pos != 0 { - t.Errorf("Expected pos to be 0, got %d", pos) - } + s.Require().NoError(err) + s.Require().Zero(pos) res = make([]byte, 2) n, err = obj2.Read(res) - if err != io.EOF { - t.Errorf("Expected err to be io.EOF, got %v", err) - } - if n != 1 { - t.Errorf("Expected n to be 1, got %d", n) - } - if res[0] != 't' { - t.Errorf("Expected res[0] to be 't', got %v", res[0]) - } + s.Require().Equal(io.EOF, err) + s.Require().Equal(1, n) + s.Require().EqualValues('t', res[0]) err = obj2.Close() - if err != nil { - t.Fatal(err) - } - - err = lo2.Unlink(ctx, id) - if err != nil { - t.Fatal(err) - } - - _, err = lo2.Open(ctx, id, pgx.LargeObjectModeRead) - if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" { - t.Errorf("Expected undefined_object error (42704), got %#v", err) - } + s.Require().NoError(err) + + err = los2.Unlink(id) + s.Require().NoError(err) + + _, err = los2.Open(id, ModeRead) + s.Require().Contains(err.Error(), "does not exist (SQLSTATE 42704)") } -*/ From 436ecc030e1d862942b4ce6cb4d873b8773ac04a Mon Sep 17 00:00:00 2001 From: cdu Date: Tue, 16 May 2023 17:41:12 -0700 Subject: [PATCH 24/40] stage --- .../gorm/largeobject/large_objects.go | 49 ++++++------------- .../{utils_test.go => 
large_objects_test.go} | 0 2 files changed, 14 insertions(+), 35 deletions(-) rename pkg/postgres/gorm/largeobject/{utils_test.go => large_objects_test.go} (100%) diff --git a/pkg/postgres/gorm/largeobject/large_objects.go b/pkg/postgres/gorm/largeobject/large_objects.go index a79ea50b3698a..12ecb0f1bdb3c 100644 --- a/pkg/postgres/gorm/largeobject/large_objects.go +++ b/pkg/postgres/gorm/largeobject/large_objects.go @@ -7,9 +7,9 @@ import ( "gorm.io/gorm" ) -// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it -// was created. +// LargeObjects is used to access the large objects API with the gorm ORM. // +// This was originally created with an API similar to the existing github.com/jackc/pgx large objects API. // For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html type LargeObjects struct { tx *gorm.DB @@ -24,7 +24,7 @@ const ( // Create creates a new large object with an unused OID assigned func (o *LargeObjects) Create() (uint32, error) { - o.tx = o.tx.Raw("SELECT lo_create(?)", 0) + o.tx = o.tx.Raw("SELECT lo_create($1)", 0) if err := o.tx.Error; err != nil { return 0, err } @@ -37,7 +37,7 @@ func (o *LargeObjects) Create() (uint32, error) { // object. func (o *LargeObjects) Open(oid uint32, mode Mode) (*LargeObject, error) { var fd int32 - o.tx = o.tx.Raw("select lo_open(?, ?)", oid, mode).Scan(&fd) + o.tx = o.tx.Raw("select lo_open($1, $2)", oid, mode).Scan(&fd) if err := o.tx.Error; err != nil { return nil, err } @@ -47,7 +47,7 @@ func (o *LargeObjects) Open(oid uint32, mode Mode) (*LargeObject, error) { // Unlink removes a large object from the database.
func (o *LargeObjects) Unlink(oid uint32) error { var count int32 - o.tx = o.tx.Raw("select lo_unlink(?)", oid).Scan(&count) + o.tx = o.tx.Raw("select lo_unlink($1)", oid).Scan(&count) if err := o.tx.Error; err != nil { return err } @@ -103,7 +103,7 @@ type LargeObject struct { // Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. func (o *LargeObject) Write(p []byte) (int, error) { var n int - o.tx = o.tx.Raw("select lowrite(?, ?)", o.fd, p) + o.tx = o.tx.Raw("select lowrite($1, $2)", o.fd, p) if err := o.tx.Error; err != nil { return n, err } @@ -124,7 +124,7 @@ func (o *LargeObject) Write(p []byte) (int, error) { // Read reads up to len(p) bytes into p returning the number of bytes read. func (o *LargeObject) Read(p []byte) (n int, err error) { var res []byte = make([]byte, 0, len(p)) - o.tx = o.tx.Raw("select loread(?, ?)", o.fd, len(p)) + o.tx = o.tx.Raw("select loread($1, $2)", o.fd, len(p)) if err = o.tx.Error; err != nil { return 0, err } @@ -145,47 +145,26 @@ func (o *LargeObject) Read(p []byte) (n int, err error) { // Seek moves the current location pointer to the new location specified by offset. func (o *LargeObject) Seek(offset int64, whence int) (int64, error) { var n int64 - result := o.tx.Raw("select lo_lseek64(?, ?, ?)", o.fd, offset, whence) - if result.Error != nil { - return 0, o.tx.Error - } - row := o.tx.Row() - row.Scan(&n) - if result.Error != nil { - return 0, result.Error - } - return n, o.tx.Error + result := o.tx.Raw("select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n) + return n, result.Error } // Tell returns the current read or write location of the large object descriptor. func (o *LargeObject) Tell() (int64, error) { var n int64 - o.tx = o.tx.Raw("select lo_tell64($1)", o.fd).Scan(&n) - return n, o.tx.Error + result := o.tx.Raw("select lo_tell64($1)", o.fd).Scan(&n) + return n, result.Error } // Truncate the large object to size. 
func (o *LargeObject) Truncate(size int64) (n int, err error) { - result := o.tx.Raw("select lo_truncate64(?, ?)", o.fd, size).Scan(&n) + result := o.tx.Raw("select lo_truncate64($1, $2)", o.fd, size).Scan(&n) return n, result.Error } // Close the large object descriptor. func (o *LargeObject) Close() error { var n int - o.tx = o.tx.Raw("select lo_close(?)", o.fd).Scan(&n) - return o.tx.Error -} - -/* -{ var n int64 -o.tx = o.tx.Raw("select lo_tell64(?)", o.fd) -if o.tx.Error != nil { -return n, o.tx.Error -} -row := o.tx.Row() -if row.Err() != nil { -return n, row.Err() + result := o.tx.Raw("select lo_close($1)", o.fd).Scan(&n) + return result.Error } -err := row.Scan(&n) -return n, err}*/ diff --git a/pkg/postgres/gorm/largeobject/utils_test.go b/pkg/postgres/gorm/largeobject/large_objects_test.go similarity index 100% rename from pkg/postgres/gorm/largeobject/utils_test.go rename to pkg/postgres/gorm/largeobject/large_objects_test.go From 9a50c29ed3d9f6e618b7495f9d4f1bbf763a08fd Mon Sep 17 00:00:00 2001 From: cdu Date: Tue, 16 May 2023 18:03:31 -0700 Subject: [PATCH 25/40] stage --- .../gorm/largeobject/large_objects.go | 35 +++++++------------ 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/pkg/postgres/gorm/largeobject/large_objects.go b/pkg/postgres/gorm/largeobject/large_objects.go index 12ecb0f1bdb3c..e52ed19895f8a 100644 --- a/pkg/postgres/gorm/largeobject/large_objects.go +++ b/pkg/postgres/gorm/largeobject/large_objects.go @@ -18,19 +18,20 @@ type LargeObjects struct { type Mode int32 const ( - ModeWrite Mode = 0x20000 - ModeRead Mode = 0x40000 + ModeWrite Mode = 0x20000 + ModeRead Mode = 0x40000 + ModeReadWrite Mode = ModeRead | ModeRead ) // Create creates a new large object with an unused OID assigned func (o *LargeObjects) Create() (uint32, error) { - o.tx = o.tx.Raw("SELECT lo_create($1)", 0) - if err := o.tx.Error; err != nil { + result := o.tx.Raw("SELECT lo_create($1)", 0) + if err := result.Error; err != nil { return 0, err } var 
oid uint32 - o.tx = o.tx.Scan(&oid) - return oid, o.tx.Error + result = result.Scan(&oid) + return oid, result.Error } // Open opens an existing large object with the given mode. ctx will also be used for all operations on the opened large @@ -103,14 +104,8 @@ type LargeObject struct { // Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. func (o *LargeObject) Write(p []byte) (int, error) { var n int - o.tx = o.tx.Raw("select lowrite($1, $2)", o.fd, p) - if err := o.tx.Error; err != nil { - return n, err - } - if err := o.tx.Row().Scan(&n); err != nil { - return n, err - } - if err := o.tx.Error; err != nil { + err := o.tx.Raw("select lowrite($1, $2)", o.fd, p).Row().Scan(&n) + if err != nil { return n, err } @@ -123,16 +118,10 @@ func (o *LargeObject) Write(p []byte) (int, error) { // Read reads up to len(p) bytes into p returning the number of bytes read. func (o *LargeObject) Read(p []byte) (n int, err error) { - var res []byte = make([]byte, 0, len(p)) - o.tx = o.tx.Raw("select loread($1, $2)", o.fd, len(p)) - if err = o.tx.Error; err != nil { - return 0, err - } - if err = o.tx.Row().Scan(&res); err != nil { - return 0, err - } + var res []byte + err = o.tx.Raw("select loread($1, $2)", o.fd, len(p)).Row().Scan(&res) copy(p, res) - if err = o.tx.Error; err != nil { + if err != nil { return len(res), err } From aa19928e7f37a42053833cabe1f925deb7664388 Mon Sep 17 00:00:00 2001 From: cdu Date: Tue, 16 May 2023 18:59:48 -0700 Subject: [PATCH 26/40] stage --- .../gorm/largeobject/large_objects.go | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/pkg/postgres/gorm/largeobject/large_objects.go b/pkg/postgres/gorm/largeobject/large_objects.go index e52ed19895f8a..e01f09263fe92 100644 --- a/pkg/postgres/gorm/largeobject/large_objects.go +++ b/pkg/postgres/gorm/largeobject/large_objects.go @@ -24,13 +24,8 @@ const ( ) // Create creates a new large object with an unused 
OID assigned -func (o *LargeObjects) Create() (uint32, error) { - result := o.tx.Raw("SELECT lo_create($1)", 0) - if err := result.Error; err != nil { - return 0, err - } - var oid uint32 - result = result.Scan(&oid) +func (o *LargeObjects) Create() (oid uint32, err error) { + result := o.tx.Raw("SELECT lo_create($1)", 0).Scan(&oid) return oid, result.Error } @@ -38,9 +33,9 @@ func (o *LargeObjects) Create() (uint32, error) { // object. func (o *LargeObjects) Open(oid uint32, mode Mode) (*LargeObject, error) { var fd int32 - o.tx = o.tx.Raw("select lo_open($1, $2)", oid, mode).Scan(&fd) - if err := o.tx.Error; err != nil { - return nil, err + result := o.tx.Raw("select lo_open($1, $2)", oid, mode).Scan(&fd) + if result.Error != nil { + return nil, result.Error } return &LargeObject{fd: fd, tx: o.tx}, nil } @@ -48,9 +43,9 @@ func (o *LargeObjects) Open(oid uint32, mode Mode) (*LargeObject, error) { // Unlink removes a large object from the database. func (o *LargeObjects) Unlink(oid uint32) error { var count int32 - o.tx = o.tx.Raw("select lo_unlink($1)", oid).Scan(&count) - if err := o.tx.Error; err != nil { - return err + result := o.tx.Raw("select lo_unlink($1)", oid).Scan(&count) + if result.Error != nil { + return result.Error } if count != 1 { return errors.New("failed to remove large object") @@ -58,26 +53,26 @@ func (o *LargeObjects) Unlink(oid uint32) error { return nil } +// Upsert insert a large object with oid. If the large object exists, +// replace it. 
func (o *LargeObjects) Upsert(oid uint32, r io.Reader) error { - obj, err := o.Open(oid, ModeWrite|ModeRead) + obj, err := o.Open(oid, ModeWrite) if err != nil { return err } + defer func() { + err = obj.Close() + }() _, err = obj.Truncate(0) if err != nil { return err } - obj.Close() - obj, err = o.Open(oid, ModeWrite) - if err != nil { - return err - } _, err = io.Copy(obj, r) if err != nil { return err } - return obj.Close() + return err } func (o *LargeObjects) Get(oid uint32, w io.Writer) error { @@ -85,12 +80,15 @@ func (o *LargeObjects) Get(oid uint32, w io.Writer) error { if err != nil { return err } + _, err = io.Copy(w, obj) - return err + if err != nil { + return obj.wrapClose(err) + } + return obj.wrapClose(err) } -// A LargeObject is a large object stored on the server. It is only valid within the transaction that it was initialized -// in. It uses the context it was initialized with for all operations. It implements these interfaces: +// A LargeObject implements the large object interface to Postgres database. It implements these interfaces: // // io.Writer // io.Reader @@ -145,7 +143,7 @@ func (o *LargeObject) Tell() (int64, error) { return n, result.Error } -// Truncate the large object to size. +// Truncate the large object to size and return the resulting size. func (o *LargeObject) Truncate(size int64) (n int, err error) { result := o.tx.Raw("select lo_truncate64($1, $2)", o.fd, size).Scan(&n) return n, result.Error @@ -157,3 +155,12 @@ func (o *LargeObject) Close() error { result := o.tx.Raw("select lo_close($1)", o.fd).Scan(&n) return result.Error } + +// wrapClose closes the large object and returns error if failed. 
Otherwise, it +// returns err +func (o *LargeObject) wrapClose(err error) error { + if closeErr := o.Close(); closeErr != nil { + return closeErr + } + return err +} From 853ed61c2cec325aef6db8becd06b426424a783a Mon Sep 17 00:00:00 2001 From: cdu Date: Tue, 16 May 2023 19:07:09 -0700 Subject: [PATCH 27/40] stage --- .../gorm/largeobject/large_objects_test.go | 39 +++++++------------ 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/pkg/postgres/gorm/largeobject/large_objects_test.go b/pkg/postgres/gorm/largeobject/large_objects_test.go index ac3b39d6f683e..69acf3f73e512 100644 --- a/pkg/postgres/gorm/largeobject/large_objects_test.go +++ b/pkg/postgres/gorm/largeobject/large_objects_test.go @@ -16,8 +16,9 @@ import ( type GormUtilsTestSuite struct { suite.Suite - db *pghelper.TestPostgres - ctx context.Context + db *pghelper.TestPostgres + ctx context.Context + gormDB *gorm.DB } func TestLargeObjects(t *testing.T) { @@ -27,6 +28,7 @@ func TestLargeObjects(t *testing.T) { func (s *GormUtilsTestSuite) SetupTest() { s.db = pghelper.ForT(s.T(), true) s.ctx = context.Background() + s.gormDB = s.db.GetGormDB().WithContext(s.ctx) } func (s *GormUtilsTestSuite) TearDownTest() { @@ -34,14 +36,12 @@ func (s *GormUtilsTestSuite) TearDownTest() { } func (s *GormUtilsTestSuite) TestUpsertGet() { - randomData := make([]byte, 100) + randomData := make([]byte, 90000) _, err := rand.Read(randomData) s.NoError(err) - gormDB := s.db.GetGormDB() - reader := bytes.NewBuffer(randomData) - tx := gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + tx := s.gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) los := LargeObjects{tx} oid, err := los.Create() s.Require().NoError(err) @@ -49,20 +49,20 @@ func (s *GormUtilsTestSuite) TestUpsertGet() { s.Require().NoError(err) s.Require().NoError(tx.Commit().Error) - tx = gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + tx = s.gormDB.Begin(&sql.TxOptions{Isolation: 
sql.LevelRepeatableRead}) los = LargeObjects{tx} writer := bytes.NewBuffer([]byte{}) s.Require().NoError(los.Get(oid, writer)) s.Require().Equal(randomData, writer.Bytes()) reader = bytes.NewBuffer([]byte("hi")) - tx = gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + tx = s.gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) los = LargeObjects{tx} err = los.Upsert(oid, reader) s.Require().NoError(err) s.Require().NoError(tx.Commit().Error) - tx = gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + tx = s.gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) los = LargeObjects{tx} writer = bytes.NewBuffer([]byte{}) writer.Reset() @@ -71,17 +71,9 @@ func (s *GormUtilsTestSuite) TestUpsertGet() { } func (s *GormUtilsTestSuite) TestLargeObject() { - ctx := context.Background() - - gormDB := s.db.GetGormDB().WithContext(ctx) - - tx := gormDB.Begin() + tx := s.gormDB.Begin() s.Require().NoError(tx.Error) - s.testLargeObject(tx) -} - -func (s *GormUtilsTestSuite) testLargeObject(tx *gorm.DB) { los := &LargeObjects{tx} id, err := los.Create() @@ -136,12 +128,7 @@ func (s *GormUtilsTestSuite) testLargeObject(tx *gorm.DB) { } func (s *GormUtilsTestSuite) TestLargeObjectsMultipleTransactions() { - ctx := context.Background() - - gormDB := s.db.GetGormDB().WithContext(ctx) - - tx := gormDB.Begin() - // tx := gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + tx := s.gormDB.Begin() s.Require().NoError(tx.Error) los := &LargeObjects{tx} @@ -159,12 +146,12 @@ func (s *GormUtilsTestSuite) TestLargeObjectsMultipleTransactions() { // IMPORTANT: Use the same connection for another query query := `select n from generate_series(1,10) n` - rows, err := gormDB.Raw(query).Rows() + rows, err := s.gormDB.Raw(query).Rows() s.Require().NoError(err) rows.Close() // Start a new transaction - tx2 := gormDB.Begin() + tx2 := s.gormDB.Begin() // tx := gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) 
s.Require().NoError(tx.Error) los2 := &LargeObjects{tx2} From 4aebc48bf904c8adc138441976e22100473cba2e Mon Sep 17 00:00:00 2001 From: cdu Date: Tue, 16 May 2023 19:10:29 -0700 Subject: [PATCH 28/40] stage --- pkg/postgres/gorm/largeobject/large_objects_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/postgres/gorm/largeobject/large_objects_test.go b/pkg/postgres/gorm/largeobject/large_objects_test.go index 69acf3f73e512..9d910030f6286 100644 --- a/pkg/postgres/gorm/largeobject/large_objects_test.go +++ b/pkg/postgres/gorm/largeobject/large_objects_test.go @@ -8,7 +8,7 @@ import ( "io" "testing" - pghelper "github.com/stackrox/rox/migrator/migrations/postgreshelper" + "github.com/stackrox/rox/pkg/postgres/pgtest" "github.com/stretchr/testify/suite" "gorm.io/gorm" ) @@ -16,7 +16,7 @@ import ( type GormUtilsTestSuite struct { suite.Suite - db *pghelper.TestPostgres + db *pgtest.TestPostgres ctx context.Context gormDB *gorm.DB } @@ -26,9 +26,9 @@ func TestLargeObjects(t *testing.T) { } func (s *GormUtilsTestSuite) SetupTest() { - s.db = pghelper.ForT(s.T(), true) + s.db = pgtest.ForT(s.T()) s.ctx = context.Background() - s.gormDB = s.db.GetGormDB().WithContext(s.ctx) + s.gormDB = s.db.GetGormDB(s.T()).WithContext(s.ctx) } func (s *GormUtilsTestSuite) TearDownTest() { From 921254eb7aac879d7153ef38300ecc3ca4ab0ab9 Mon Sep 17 00:00:00 2001 From: cdu Date: Tue, 16 May 2023 19:14:58 -0700 Subject: [PATCH 29/40] stage --- pkg/postgres/gorm/largeobject/large_objects_test.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pkg/postgres/gorm/largeobject/large_objects_test.go b/pkg/postgres/gorm/largeobject/large_objects_test.go index 9d910030f6286..b042f280e50c1 100644 --- a/pkg/postgres/gorm/largeobject/large_objects_test.go +++ b/pkg/postgres/gorm/largeobject/large_objects_test.go @@ -36,6 +36,7 @@ func (s *GormUtilsTestSuite) TearDownTest() { } func (s *GormUtilsTestSuite) TestUpsertGet() { + // Write a long 
file randomData := make([]byte, 90000) _, err := rand.Read(randomData) s.NoError(err) @@ -49,11 +50,14 @@ func (s *GormUtilsTestSuite) TestUpsertGet() { s.Require().NoError(err) s.Require().NoError(tx.Commit().Error) + // Read it back and verify tx = s.gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) los = LargeObjects{tx} writer := bytes.NewBuffer([]byte{}) s.Require().NoError(los.Get(oid, writer)) + s.Require().NoError(tx.Commit().Error) + // Overwrite it s.Require().Equal(randomData, writer.Bytes()) reader = bytes.NewBuffer([]byte("hi")) tx = s.gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) @@ -62,15 +66,17 @@ func (s *GormUtilsTestSuite) TestUpsertGet() { s.Require().NoError(err) s.Require().NoError(tx.Commit().Error) + // Read it back and verify tx = s.gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) los = LargeObjects{tx} writer = bytes.NewBuffer([]byte{}) writer.Reset() s.Require().NoError(los.Get(oid, writer)) s.Require().Equal([]byte("hi"), writer.Bytes()) + s.Require().NoError(tx.Commit().Error) } -func (s *GormUtilsTestSuite) TestLargeObject() { +func (s *GormUtilsTestSuite) TestLargeObjectSingleTransaction() { tx := s.gormDB.Begin() s.Require().NoError(tx.Error) @@ -127,7 +133,7 @@ func (s *GormUtilsTestSuite) TestLargeObject() { s.Require().Contains(err.Error(), "does not exist (SQLSTATE 42704)") } -func (s *GormUtilsTestSuite) TestLargeObjectsMultipleTransactions() { +func (s *GormUtilsTestSuite) TestLargeObjectMultipleTransactions() { tx := s.gormDB.Begin() s.Require().NoError(tx.Error) los := &LargeObjects{tx} From 28b45d169deb54217fec03836656ca368737a5ec Mon Sep 17 00:00:00 2001 From: cdu Date: Wed, 17 May 2023 09:53:47 -0700 Subject: [PATCH 30/40] stage --- .../migration.go | 128 ++++++++------ .../migration_test.go | 164 +++++++----------- .../gorm/largeobject/large_objects.go | 15 +- 3 files changed, 145 insertions(+), 162 deletions(-) diff --git 
a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go index 535077e175b9f..6fcf38125c5ab 100644 --- a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go @@ -5,19 +5,25 @@ import ( "database/sql" "os" + timestamp "github.com/gogo/protobuf/types" "github.com/pkg/errors" "github.com/stackrox/rox/generated/storage" "github.com/stackrox/rox/migrator/migrations" "github.com/stackrox/rox/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema" "github.com/stackrox/rox/migrator/types" "github.com/stackrox/rox/pkg/logging" + "github.com/stackrox/rox/pkg/postgres/gorm/largeobject" + "github.com/stackrox/rox/pkg/postgres/pgutils" "github.com/stackrox/rox/pkg/sac" "gorm.io/gorm" ) const ( scannerDefBlobName = "/offline/scanner/scanner-defs.zip" - scannerDefPath = "/var/lib/stackrox/scannerdefinitions/scanner-defs.zip" +) + +var ( + scannerDefPath = "/var/lib/stackrox/scannerdefinitions/scanner-defs.zip" ) var ( @@ -27,78 +33,90 @@ var ( Run: func(databases *types.Databases) error { err := moveToBlobs(databases.GormDB) if err != nil { - return errors.Wrap(err, "updating policies") + return errors.Wrap(err, "moving persistent files to blobs") } return nil }, } - log = logging.LoggerForModule() - toBeMigrated = map[string]string{ - scannerDefPath: scannerDefBlobName, - } + log = logging.LoggerForModule() ) func moveToBlobs(db *gorm.DB) (err error) { ctx := sac.WithAllAccess(context.Background()) db = db.WithContext(ctx).Table(schema.BlobsTableName) - if err := db.WithContext(ctx).AutoMigrate(schema.CreateTableBlobsStmt.GormModel); err != nil { + pgutils.CreateTableFromModel(context.Background(), db, schema.CreateTableBlobsStmt) + + tx := db.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + if err = moveScannerDefination(tx); err != nil { + result := tx.Rollback() + if result.Error != nil { + return 
result.Error + } return err } - tx := db.Model(schema.CreateTableBlobsStmt.GormModel).Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - defer func() { - if err != nil { - tx.Rollback() - } - }() - for p, blobName := range toBeMigrated { - f, err := os.Open(p) - if errors.Is(err, os.ErrNotExist) { - continue + return tx.Commit().Error +} + +func moveScannerDefination(tx *gorm.DB) error { + stat, err := os.Stat(scannerDefPath) + if err != nil { + if os.IsNotExist(err) || stat.IsDir() { + return nil } + return err + } + modTime, err := timestamp.TimestampProto(stat.ModTime()) + if err != nil { + return errors.Wrapf(err, "invalid timestamp %v", stat.ModTime()) + } + fd, err := os.Open(scannerDefPath) + if os.IsNotExist(err) { + return nil + } + if err != nil { + return errors.Wrapf(err, "failed to open %s", scannerDefPath) + } + + // Prepare blob + blob := &storage.Blob{ + Name: scannerDefBlobName, + Oid: 0, + Length: stat.Size(), + LastUpdated: timestamp.TimestampNow(), + ModifiedTime: modTime, + } + los := largeobject.LargeObjects{DB: tx} + + // Find the blob if it exists + var targets []schema.Blobs + result := tx.Limit(1).Where(&schema.Blobs{Name: scannerDefBlobName}).Find(&targets) + if result.Error != nil { + return result.Error + } + + if len(targets) == 0 { + blob.Oid, err = los.Create() if err != nil { - return err - } - target := &schema.Blobs{Name: blobName} - result := tx.Take(target) - if result.Error != nil { - return result.Error + return errors.Wrap(err, "failed to create large object") } - var blob *storage.Blob - if result.RowsAffected == 0 { - // Create - // err := o.tx.QueryRow(ctx, "select lo_create($1)", oid).Scan(&oid) - var oid int - tx.Select("lo_create(0)").Find() - tx.Exec("SELECT lo_create(0)").Find(&oid) - blob = &storage.Blob{ - Name: blobName, - Oid: 0, - Length: 0, - ModifiedTime: nil, - } - } else { - // Update - existingBlob, err := schema.ConvertBlobToProto(target) - if err != nil { - return err - } - blob = &storage.Blob{ - 
Name: blobName, - Oid: existingBlob.Oid, - Length: 0, - ModifiedTime: nil, - } - } - blobModel, err := schema.ConvertBlobFromProto(blob) + } else { + // Update + existingBlob, err := schema.ConvertBlobToProto(&targets[0]) if err != nil { - return err + return errors.Wrapf(err, "existing blob is not valid %+v", targets[0]) } - tx.Exec("") - tx = tx.FirstOrCreate(blobModel) - + blob.Oid = existingBlob.Oid } - return tx.Commit().Error + blobModel, err := schema.ConvertBlobFromProto(blob) + if err != nil { + return errors.Wrapf(err, "failed to convert blob to blob model %+v", blob) + } + tx = tx.FirstOrCreate(blobModel) + if tx.Error != nil { + return errors.Wrap(tx.Error, "failed to create blob metadata") + } + return los.Upsert(blob.Oid, fd) } func init() { diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go index 5b8a68f457b39..020adfc2da288 100644 --- a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go @@ -1,30 +1,23 @@ -//go:build sql_integration - package m179tom180 import ( - "context" - "fmt" + "bytes" + "crypto/rand" + "io" + "os" "testing" - "github.com/stackrox/rox/generated/storage" - frozenSchema "github.com/stackrox/rox/migrator/migrations/frozenschema/v73" - policyPostgresStore "github.com/stackrox/rox/migrator/migrations/m_179_to_m_180_openshift_policy_exclusions/postgres" + "github.com/stackrox/rox/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema" pghelper "github.com/stackrox/rox/migrator/migrations/postgreshelper" - "github.com/stackrox/rox/migrator/types" - "github.com/stackrox/rox/pkg/fixtures" + "github.com/stackrox/rox/pkg/postgres/gorm/largeobject" "github.com/stackrox/rox/pkg/postgres/pgutils" - "github.com/stackrox/rox/pkg/sac" - "github.com/stackrox/rox/pkg/search" - "github.com/stretchr/testify/require" 
"github.com/stretchr/testify/suite" ) type categoriesMigrationTestSuite struct { suite.Suite - db *pghelper.TestPostgres - policyStore policyPostgresStore.Store + db *pghelper.TestPostgres } func TestMigration(t *testing.T) { @@ -33,9 +26,6 @@ func TestMigration(t *testing.T) { func (s *categoriesMigrationTestSuite) SetupTest() { s.db = pghelper.ForT(s.T(), true) - s.policyStore = policyPostgresStore.New(s.db.DB) - pgutils.CreateTableFromModel(context.Background(), s.db.GetGormDB(), frozenSchema.CreateTablePoliciesStmt) - } func (s *categoriesMigrationTestSuite) TearDownTest() { @@ -43,87 +33,63 @@ func (s *categoriesMigrationTestSuite) TearDownTest() { } func (s *categoriesMigrationTestSuite) TestMigration() { - ctx := sac.WithAllAccess(context.Background()) - testPolicy := fixtures.GetPolicy() - testPolicy.Id = "ed8c7957-14de-40bc-aeab-d27ceeecfa7b" - testPolicy.Name = "Iptables Executed in Privileged Container" - testPolicy.Description = "Alert on privileged pods that execute iptables" - testPolicy.PolicySections = []*storage.PolicySection{ - { - PolicyGroups: []*storage.PolicyGroup{ - { - FieldName: "Privileged Container", - Values: []*storage.PolicyValue{ - { - Value: "true", - }, - }, - }, - { - FieldName: "Process Name", - Values: []*storage.PolicyValue{ - { - Value: "iptables", - }, - }, - }, - { - FieldName: "Process UID", - Values: []*storage.PolicyValue{ - { - Value: "0", - }, - }, - }, - }, - }, - } - require.NoError(s.T(), s.policyStore.Upsert(ctx, testPolicy)) - // insert other policies in db for migration to run successfully - policies := []string{ - "fb8f8732-c31d-496b-8fb1-d5abe6056e27", - "880fd131-46f0-43d2-82c9-547f5aa7e043", - "47cb9e0a-879a-417b-9a8f-de644d7c8a77", - "6226d4ad-7619-4a0b-a160-46373cfcee66", - "436811e7-892f-4da6-a0f5-8cc459f1b954", - "742e0361-bddd-4a2d-8758-f2af6197f61d", - "16c95922-08c4-41b6-a721-dc4b2a806632", - "fe9de18b-86db-44d5-a7c4-74173ccffe2e", - "dce17697-1b72-49d2-b18a-05d893cd9368", - 
"f4996314-c3d7-4553-803b-b24ce7febe48", - "a9b9ecf7-9707-4e32-8b62-d03018ed454f", - "32d770b9-c6ba-4398-b48a-0c3e807644ed", - "f95ff08d-130a-465a-a27e-32ed1fb05555", - } - - policyName := "policy description %d" - for i := 0; i < len(policies); i++ { - require.NoError(s.T(), s.policyStore.Upsert(ctx, &storage.Policy{ - Id: policies[i], - Name: fmt.Sprintf(policyName, i), - })) - } - dbs := &types.Databases{ - PostgresDB: s.db.DB, - GormDB: s.db.GetGormDB(), - } - - q := search.NewQueryBuilder().AddExactMatches(search.PolicyID, testPolicy.GetId()).ProtoQuery() - policyPremigration, err := s.policyStore.GetByQuery(ctx, q) - s.NoError(err) - s.Empty(policyPremigration[0].Exclusions) - s.NoError(migration.Run(dbs)) - expectedExclusions := []string{"Don't alert on ovnkube-node deployment in openshift-ovn-kubernetes Namespace", - "Don't alert on haproxy-* deployment in openshift-vsphere-infra namespace", - "Don't alert on keepalived-* deployment in openshift-vsphere-infra namespace", - "Don't alert on coredns-* deployment in openshift-vsphere-infra namespace"} - query := search.NewQueryBuilder().AddExactMatches(search.PolicyID, testPolicy.GetId()).ProtoQuery() - policy, err := s.policyStore.GetByQuery(ctx, query) - s.NoError(err) - var actualExclusions []string - for _, excl := range policy[0].Exclusions { - actualExclusions = append(actualExclusions, excl.Name) - } - s.ElementsMatch(actualExclusions, expectedExclusions, "exclusion do not match after migration") + // Nothing to migrate + s.Require().NoError(moveToBlobs(s.db.GetGormDB())) + + // Prepare persistent file + size := 90000 + randomData := make([]byte, size) + _, err := rand.Read(randomData) + s.Require().NoError(err) + reader := bytes.NewBuffer(randomData) + + file, err := os.CreateTemp("", "move-blob") + defer func() { + s.NoError(file.Close()) + s.NoError(os.Remove(file.Name())) + }() + scannerDefPath = file.Name() + n, err := io.Copy(file, reader) + s.Require().NoError(err) + + s.Require().EqualValues(size, 
n) + + // Migrate + s.Require().NoError(moveToBlobs(s.db.GetGormDB())) + + // Verify Blob + blobModel := &schema.Blobs{Name: scannerDefBlobName} + s.Require().NoError(s.db.GetGormDB().First(&blobModel).Error) + + blob, err := schema.ConvertBlobToProto(blobModel) + s.Require().NoError(err) + s.Equal(scannerDefBlobName, blob.GetName()) + s.EqualValues(size, blob.GetLength()) + + fileInfo, err := file.Stat() + s.Require().NoError(err) + modTime := pgutils.NilOrTime(blob.GetModifiedTime()) + s.Equal(fileInfo.ModTime().UTC(), modTime.UTC()) + + // Verify Data + buf := bytes.NewBuffer([]byte{}) + + tx := s.db.GetGormDB().Begin() + s.Require().NoError(err) + los := &largeobject.LargeObjects{DB: tx} + s.Require().NoError(los.Get(blob.Oid, buf)) + s.Equal(len(randomData), buf.Len()) + s.Equal(randomData, buf.Bytes()) + s.NoError(tx.Commit().Error) + // Test re-entry + s.Require().NoError(moveToBlobs(s.db.GetGormDB())) + buf.Reset() + tx = s.db.GetGormDB().Begin() + los = &largeobject.LargeObjects{DB: tx} + s.Require().NoError(err) + s.Require().NoError(los.Get(blob.Oid, buf)) + s.Equal(len(randomData), buf.Len()) + s.Equal(randomData, buf.Bytes()) + s.NoError(tx.Commit().Error) } diff --git a/pkg/postgres/gorm/largeobject/large_objects.go b/pkg/postgres/gorm/largeobject/large_objects.go index e01f09263fe92..262d808ef02fe 100644 --- a/pkg/postgres/gorm/largeobject/large_objects.go +++ b/pkg/postgres/gorm/largeobject/large_objects.go @@ -12,20 +12,19 @@ import ( // This is originally created with similar API with existing github.com/jackc/pgx // For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html type LargeObjects struct { - tx *gorm.DB + *gorm.DB } type Mode int32 const ( - ModeWrite Mode = 0x20000 - ModeRead Mode = 0x40000 - ModeReadWrite Mode = ModeRead | ModeRead + ModeWrite Mode = 0x20000 + ModeRead Mode = 0x40000 ) // Create creates a new large object with an unused OID assigned func (o *LargeObjects) Create() (oid uint32, err error) { 
- result := o.tx.Raw("SELECT lo_create($1)", 0).Scan(&oid) + result := o.Raw("SELECT lo_create($1)", 0).Scan(&oid) return oid, result.Error } @@ -33,17 +32,17 @@ func (o *LargeObjects) Create() (oid uint32, err error) { // object. func (o *LargeObjects) Open(oid uint32, mode Mode) (*LargeObject, error) { var fd int32 - result := o.tx.Raw("select lo_open($1, $2)", oid, mode).Scan(&fd) + result := o.Raw("select lo_open($1, $2)", oid, mode).Scan(&fd) if result.Error != nil { return nil, result.Error } - return &LargeObject{fd: fd, tx: o.tx}, nil + return &LargeObject{fd: fd, tx: o.DB}, nil } // Unlink removes a large object from the database. func (o *LargeObjects) Unlink(oid uint32) error { var count int32 - result := o.tx.Raw("select lo_unlink($1)", oid).Scan(&count) + result := o.Raw("select lo_unlink($1)", oid).Scan(&count) if result.Error != nil { return result.Error } From c9c7c4e9562d226d41511cc5c77adb148f0adb93 Mon Sep 17 00:00:00 2001 From: cdu Date: Mon, 15 May 2023 11:38:28 -0700 Subject: [PATCH 31/40] Migrate persistent data #1 - scanner definition --- central/probeupload/manager/manager_impl.go | 1 + central/scannerdefinitions/handler/handler.go | 10 +- .../handler/handler_test.go | 6 +- .../migration.go | 124 +++++++++++ .../migration_test.go | 95 ++++++++ .../schema/blobs.go | 35 +++ .../schema/convert_blobs.go | 28 +++ .../schema/convert_blobs_test.go | 20 ++ .../schema/gen.go | 3 + .../gorm/largeobject/large_objects.go | 165 ++++++++++++++ .../gorm/largeobject/large_objects_test.go | 208 ++++++++++++++++++ 11 files changed, 689 insertions(+), 6 deletions(-) create mode 100644 migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go create mode 100644 migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go create mode 100644 migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/blobs.go create mode 100644 migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs.go create mode 100644 
migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs_test.go create mode 100644 migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/gen.go create mode 100644 pkg/postgres/gorm/largeobject/large_objects.go create mode 100644 pkg/postgres/gorm/largeobject/large_objects_test.go diff --git a/central/probeupload/manager/manager_impl.go b/central/probeupload/manager/manager_impl.go index da22f0b6a55d2..6462785aa042a 100644 --- a/central/probeupload/manager/manager_impl.go +++ b/central/probeupload/manager/manager_impl.go @@ -49,6 +49,7 @@ type manager struct { func newManager(persistenceRoot string) *manager { return &manager{ + rootDir: filepath.Join(persistenceRoot, rootDirName), freeDiskThreshold: defaultFreeDiskThreshold, } } diff --git a/central/scannerdefinitions/handler/handler.go b/central/scannerdefinitions/handler/handler.go index 7e6a33db5d188..c78e28e2e5f27 100644 --- a/central/scannerdefinitions/handler/handler.go +++ b/central/scannerdefinitions/handler/handler.go @@ -26,6 +26,7 @@ import ( "github.com/stackrox/rox/pkg/httputil/proxy" "github.com/stackrox/rox/pkg/logging" "github.com/stackrox/rox/pkg/postgres/pgutils" + "github.com/stackrox/rox/pkg/sac" "github.com/stackrox/rox/pkg/sync" "github.com/stackrox/rox/pkg/utils" "google.golang.org/grpc/codes" @@ -221,7 +222,7 @@ func (h *httpHandler) handleScannerDefsFile(ctx context.Context, zipF *zip.File) Length: zipF.FileInfo().Size(), } - if err := h.blobStore.Upsert(ctx, b, r); err != nil { + if err := h.blobStore.Upsert(sac.WithAllAccess(ctx), b, r); err != nil { return errors.Wrap(err, "writing scanner definitions") } @@ -313,12 +314,14 @@ func (h *httpHandler) cleanupUpdaters(cleanupAge time.Duration) { } func (h *httpHandler) openOfflineBlob(ctx context.Context) (*vulDefFile, error) { - snap, err := snapshot.TakeBlobSnapshot(ctx, h.blobStore, offlineScannerDefsName) + snap, err := snapshot.TakeBlobSnapshot(sac.WithAllAccess(ctx), h.blobStore, offlineScannerDefsName) if err 
!= nil { // If the blob does not exist, return no reader. if errors.Is(err, snapshot.ErrBlobNotExist) { + log.Warnf("Blob %s does not exist", offlineScannerDefsName) return nil, nil } + log.Warnf("Cannnot take a snapshot of Blob %q: %v", offlineScannerDefsName, err) return nil, err } modTime := time.Time{} @@ -336,8 +339,7 @@ func (h *httpHandler) openOfflineBlob(ctx context.Context) (*vulDefFile, error) func (h *httpHandler) openMostRecentDefinitions(ctx context.Context, uuid string) (file *vulDefFile, err error) { // If in offline mode or uuid is not provided, default to the offline file. if !h.online || uuid == "" { - file, err = h.openOfflineBlob(ctx) - return + return h.openOfflineBlob(ctx) } // Start the updater, can be called multiple times for the same uuid, but will diff --git a/central/scannerdefinitions/handler/handler_test.go b/central/scannerdefinitions/handler/handler_test.go index 4553eedb53751..8edc6ca12ff23 100644 --- a/central/scannerdefinitions/handler/handler_test.go +++ b/central/scannerdefinitions/handler/handler_test.go @@ -61,8 +61,10 @@ func (s *handlerTestSuite) SetupTest() { func (s *handlerTestSuite) TearDownSuite() { entries, err := os.ReadDir(s.tmpDir) s.NoError(err) - s.Len(entries, 1) - s.True(strings.HasPrefix(entries[0].Name(), definitionsBaseDir)) + s.Less(len(entries), 1) + if len(entries) == 1 { + s.True(strings.HasPrefix(entries[0].Name(), definitionsBaseDir)) + } s.testDB.Teardown(s.T()) utils.IgnoreError(func() error { return os.RemoveAll(s.tmpDir) }) diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go new file mode 100644 index 0000000000000..6fcf38125c5ab --- /dev/null +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go @@ -0,0 +1,124 @@ +package m179tom180 + +import ( + "context" + "database/sql" + "os" + + timestamp "github.com/gogo/protobuf/types" + "github.com/pkg/errors" + 
"github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/migrator/migrations" + "github.com/stackrox/rox/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema" + "github.com/stackrox/rox/migrator/types" + "github.com/stackrox/rox/pkg/logging" + "github.com/stackrox/rox/pkg/postgres/gorm/largeobject" + "github.com/stackrox/rox/pkg/postgres/pgutils" + "github.com/stackrox/rox/pkg/sac" + "gorm.io/gorm" +) + +const ( + scannerDefBlobName = "/offline/scanner/scanner-defs.zip" +) + +var ( + scannerDefPath = "/var/lib/stackrox/scannerdefinitions/scanner-defs.zip" +) + +var ( + migration = types.Migration{ + StartingSeqNum: 180, + VersionAfter: &storage.Version{SeqNum: 181}, + Run: func(databases *types.Databases) error { + err := moveToBlobs(databases.GormDB) + if err != nil { + return errors.Wrap(err, "moving persistent files to blobs") + } + return nil + }, + } + log = logging.LoggerForModule() +) + +func moveToBlobs(db *gorm.DB) (err error) { + ctx := sac.WithAllAccess(context.Background()) + db = db.WithContext(ctx).Table(schema.BlobsTableName) + pgutils.CreateTableFromModel(context.Background(), db, schema.CreateTableBlobsStmt) + + tx := db.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + if err = moveScannerDefination(tx); err != nil { + result := tx.Rollback() + if result.Error != nil { + return result.Error + } + return err + } + + return tx.Commit().Error +} + +func moveScannerDefination(tx *gorm.DB) error { + stat, err := os.Stat(scannerDefPath) + if err != nil { + if os.IsNotExist(err) || stat.IsDir() { + return nil + } + return err + } + modTime, err := timestamp.TimestampProto(stat.ModTime()) + if err != nil { + return errors.Wrapf(err, "invalid timestamp %v", stat.ModTime()) + } + fd, err := os.Open(scannerDefPath) + if os.IsNotExist(err) { + return nil + } + if err != nil { + return errors.Wrapf(err, "failed to open %s", scannerDefPath) + } + + // Prepare blob + blob := &storage.Blob{ + Name: scannerDefBlobName, + Oid: 0, + 
Length: stat.Size(), + LastUpdated: timestamp.TimestampNow(), + ModifiedTime: modTime, + } + los := largeobject.LargeObjects{DB: tx} + + // Find the blob if it exists + var targets []schema.Blobs + result := tx.Limit(1).Where(&schema.Blobs{Name: scannerDefBlobName}).Find(&targets) + if result.Error != nil { + return result.Error + } + + if len(targets) == 0 { + blob.Oid, err = los.Create() + if err != nil { + return errors.Wrap(err, "failed to create large object") + } + } else { + // Update + existingBlob, err := schema.ConvertBlobToProto(&targets[0]) + if err != nil { + return errors.Wrapf(err, "existing blob is not valid %+v", targets[0]) + } + blob.Oid = existingBlob.Oid + } + blobModel, err := schema.ConvertBlobFromProto(blob) + if err != nil { + return errors.Wrapf(err, "failed to convert blob to blob model %+v", blob) + } + tx = tx.FirstOrCreate(blobModel) + if tx.Error != nil { + return errors.Wrap(tx.Error, "failed to create blob metadata") + } + return los.Upsert(blob.Oid, fd) +} + +func init() { + migrations.MustRegisterMigration(migration) +} diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go new file mode 100644 index 0000000000000..020adfc2da288 --- /dev/null +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go @@ -0,0 +1,95 @@ +package m179tom180 + +import ( + "bytes" + "crypto/rand" + "io" + "os" + "testing" + + "github.com/stackrox/rox/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema" + pghelper "github.com/stackrox/rox/migrator/migrations/postgreshelper" + "github.com/stackrox/rox/pkg/postgres/gorm/largeobject" + "github.com/stackrox/rox/pkg/postgres/pgutils" + "github.com/stretchr/testify/suite" +) + +type categoriesMigrationTestSuite struct { + suite.Suite + + db *pghelper.TestPostgres +} + +func TestMigration(t *testing.T) { + suite.Run(t, new(categoriesMigrationTestSuite)) +} + +func (s 
*categoriesMigrationTestSuite) SetupTest() { + s.db = pghelper.ForT(s.T(), true) +} + +func (s *categoriesMigrationTestSuite) TearDownTest() { + s.db.Teardown(s.T()) +} + +func (s *categoriesMigrationTestSuite) TestMigration() { + // Nothing to migrate + s.Require().NoError(moveToBlobs(s.db.GetGormDB())) + + // Prepare persistent file + size := 90000 + randomData := make([]byte, size) + _, err := rand.Read(randomData) + s.Require().NoError(err) + reader := bytes.NewBuffer(randomData) + + file, err := os.CreateTemp("", "move-blob") + defer func() { + s.NoError(file.Close()) + s.NoError(os.Remove(file.Name())) + }() + scannerDefPath = file.Name() + n, err := io.Copy(file, reader) + s.Require().NoError(err) + + s.Require().EqualValues(size, n) + + // Migrate + s.Require().NoError(moveToBlobs(s.db.GetGormDB())) + + // Verify Blob + blobModel := &schema.Blobs{Name: scannerDefBlobName} + s.Require().NoError(s.db.GetGormDB().First(&blobModel).Error) + + blob, err := schema.ConvertBlobToProto(blobModel) + s.Require().NoError(err) + s.Equal(scannerDefBlobName, blob.GetName()) + s.EqualValues(size, blob.GetLength()) + + fileInfo, err := file.Stat() + s.Require().NoError(err) + modTime := pgutils.NilOrTime(blob.GetModifiedTime()) + s.Equal(fileInfo.ModTime().UTC(), modTime.UTC()) + + // Verify Data + buf := bytes.NewBuffer([]byte{}) + + tx := s.db.GetGormDB().Begin() + s.Require().NoError(err) + los := &largeobject.LargeObjects{DB: tx} + s.Require().NoError(los.Get(blob.Oid, buf)) + s.Equal(len(randomData), buf.Len()) + s.Equal(randomData, buf.Bytes()) + s.NoError(tx.Commit().Error) + + // Test re-entry + s.Require().NoError(moveToBlobs(s.db.GetGormDB())) + buf.Reset() + tx = s.db.GetGormDB().Begin() + los = &largeobject.LargeObjects{DB: tx} + s.Require().NoError(err) + s.Require().NoError(los.Get(blob.Oid, buf)) + s.Equal(len(randomData), buf.Len()) + s.Equal(randomData, buf.Bytes()) + s.NoError(tx.Commit().Error) +} diff --git 
a/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/blobs.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/blobs.go new file mode 100644 index 0000000000000..c6401dcf1b832 --- /dev/null +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/blobs.go @@ -0,0 +1,35 @@ +// Code generated by pg-bindings generator. DO NOT EDIT. + +package schema + +import ( + "reflect" + + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/postgres" + "github.com/stackrox/rox/pkg/postgres/walker" +) + +var ( + // CreateTableBlobsStmt holds the create statement for table `blobs`. + CreateTableBlobsStmt = &postgres.CreateStmts{ + GormModel: (*Blobs)(nil), + Children: []*postgres.CreateStmts{}, + } + + // BlobsSchema is the go schema for table `blobs`. + BlobsSchema = func() *walker.Schema { + schema := walker.Walk(reflect.TypeOf((*storage.Blob)(nil)), "blobs") + return schema + }() +) + +const ( + BlobsTableName = "blobs" +) + +// Blobs holds the Gorm model for Postgres table `blobs`. +type Blobs struct { + Name string `gorm:"column:name;type:varchar;primaryKey"` + Serialized []byte `gorm:"column:serialized;type:bytea"` +} diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs.go new file mode 100644 index 0000000000000..25223842c6131 --- /dev/null +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs.go @@ -0,0 +1,28 @@ +// Code generated by pg-bindings generator. DO NOT EDIT. 
+package schema + +import ( + "github.com/stackrox/rox/generated/storage" +) + +// ConvertBlobFromProto converts a `*storage.Blob` to Gorm model +func ConvertBlobFromProto(obj *storage.Blob) (*Blobs, error) { + serialized, err := obj.Marshal() + if err != nil { + return nil, err + } + model := &Blobs{ + Name: obj.GetName(), + Serialized: serialized, + } + return model, nil +} + +// ConvertBlobToProto converts Gorm model `Blobs` to its protobuf type object +func ConvertBlobToProto(m *Blobs) (*storage.Blob, error) { + var msg storage.Blob + if err := msg.Unmarshal(m.Serialized); err != nil { + return nil, err + } + return &msg, nil +} diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs_test.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs_test.go new file mode 100644 index 0000000000000..a2f2300ee3c51 --- /dev/null +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/convert_blobs_test.go @@ -0,0 +1,20 @@ +// Code generated by pg-bindings generator. DO NOT EDIT. 
+package schema + +import ( + "testing" + + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/rox/pkg/testutils" + "github.com/stretchr/testify/assert" +) + +func TestBlobSerialization(t *testing.T) { + obj := &storage.Blob{} + assert.NoError(t, testutils.FullInit(obj, testutils.UniqueInitializer(), testutils.JSONFieldsFilter)) + m, err := ConvertBlobFromProto(obj) + assert.NoError(t, err) + conv, err := ConvertBlobToProto(m) + assert.NoError(t, err) + assert.Equal(t, obj, conv) +} diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/gen.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/gen.go new file mode 100644 index 0000000000000..8265e535a968e --- /dev/null +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/gen.go @@ -0,0 +1,3 @@ +package schema + +//go:generate pg-schema-migration-helper --type=storage.Blob diff --git a/pkg/postgres/gorm/largeobject/large_objects.go b/pkg/postgres/gorm/largeobject/large_objects.go new file mode 100644 index 0000000000000..262d808ef02fe --- /dev/null +++ b/pkg/postgres/gorm/largeobject/large_objects.go @@ -0,0 +1,165 @@ +package largeobject + +import ( + "errors" + "io" + + "gorm.io/gorm" +) + +// LargeObjects is used to access the large objects API with gorm CRM. +// +// This is originally created with similar API with existing github.com/jackc/pgx +// For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html +type LargeObjects struct { + *gorm.DB +} + +type Mode int32 + +const ( + ModeWrite Mode = 0x20000 + ModeRead Mode = 0x40000 +) + +// Create creates a new large object with an unused OID assigned +func (o *LargeObjects) Create() (oid uint32, err error) { + result := o.Raw("SELECT lo_create($1)", 0).Scan(&oid) + return oid, result.Error +} + +// Open opens an existing large object with the given mode. ctx will also be used for all operations on the opened large +// object. 
+func (o *LargeObjects) Open(oid uint32, mode Mode) (*LargeObject, error) { + var fd int32 + result := o.Raw("select lo_open($1, $2)", oid, mode).Scan(&fd) + if result.Error != nil { + return nil, result.Error + } + return &LargeObject{fd: fd, tx: o.DB}, nil +} + +// Unlink removes a large object from the database. +func (o *LargeObjects) Unlink(oid uint32) error { + var count int32 + result := o.Raw("select lo_unlink($1)", oid).Scan(&count) + if result.Error != nil { + return result.Error + } + if count != 1 { + return errors.New("failed to remove large object") + } + return nil +} + +// Upsert insert a large object with oid. If the large object exists, +// replace it. +func (o *LargeObjects) Upsert(oid uint32, r io.Reader) error { + obj, err := o.Open(oid, ModeWrite) + if err != nil { + return err + } + defer func() { + err = obj.Close() + }() + _, err = obj.Truncate(0) + if err != nil { + return err + } + _, err = io.Copy(obj, r) + if err != nil { + return err + } + + return err +} + +func (o *LargeObjects) Get(oid uint32, w io.Writer) error { + obj, err := o.Open(oid, ModeRead) + if err != nil { + return err + } + + _, err = io.Copy(w, obj) + if err != nil { + return obj.wrapClose(err) + } + return obj.wrapClose(err) +} + +// A LargeObject implements the large object interface to Postgres database. It implements these interfaces: +// +// io.Writer +// io.Reader +// io.Seeker +// io.Closer +type LargeObject struct { + tx *gorm.DB + fd int32 +} + +// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. +func (o *LargeObject) Write(p []byte) (int, error) { + var n int + err := o.tx.Raw("select lowrite($1, $2)", o.fd, p).Row().Scan(&n) + if err != nil { + return n, err + } + + if n < 0 { + return 0, errors.New("failed to write to large object") + } + + return n, nil +} + +// Read reads up to len(p) bytes into p returning the number of bytes read. 
+func (o *LargeObject) Read(p []byte) (n int, err error) { + var res []byte + err = o.tx.Raw("select loread($1, $2)", o.fd, len(p)).Row().Scan(&res) + copy(p, res) + if err != nil { + return len(res), err + } + + if len(res) < len(p) { + err = io.EOF + } + return len(res), err +} + +// Seek moves the current location pointer to the new location specified by offset. +func (o *LargeObject) Seek(offset int64, whence int) (int64, error) { + var n int64 + result := o.tx.Raw("select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n) + return n, result.Error +} + +// Tell returns the current read or write location of the large object descriptor. +func (o *LargeObject) Tell() (int64, error) { + var n int64 + result := o.tx.Raw("select lo_tell64($1)", o.fd).Scan(&n) + return n, result.Error +} + +// Truncate the large object to size and return the resulting size. +func (o *LargeObject) Truncate(size int64) (n int, err error) { + result := o.tx.Raw("select lo_truncate64($1, $2)", o.fd, size).Scan(&n) + return n, result.Error +} + +// Close the large object descriptor. +func (o *LargeObject) Close() error { + var n int + result := o.tx.Raw("select lo_close($1)", o.fd).Scan(&n) + return result.Error +} + +// wrapClose closes the large object and returns error if failed. 
Otherwise, it +// returns err +func (o *LargeObject) wrapClose(err error) error { + if closeErr := o.Close(); closeErr != nil { + return closeErr + } + return err +} diff --git a/pkg/postgres/gorm/largeobject/large_objects_test.go b/pkg/postgres/gorm/largeobject/large_objects_test.go new file mode 100644 index 0000000000000..b042f280e50c1 --- /dev/null +++ b/pkg/postgres/gorm/largeobject/large_objects_test.go @@ -0,0 +1,208 @@ +package largeobject + +import ( + "bytes" + "context" + "crypto/rand" + "database/sql" + "io" + "testing" + + "github.com/stackrox/rox/pkg/postgres/pgtest" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +type GormUtilsTestSuite struct { + suite.Suite + + db *pgtest.TestPostgres + ctx context.Context + gormDB *gorm.DB +} + +func TestLargeObjects(t *testing.T) { + suite.Run(t, new(GormUtilsTestSuite)) +} + +func (s *GormUtilsTestSuite) SetupTest() { + s.db = pgtest.ForT(s.T()) + s.ctx = context.Background() + s.gormDB = s.db.GetGormDB(s.T()).WithContext(s.ctx) +} + +func (s *GormUtilsTestSuite) TearDownTest() { + s.db.Teardown(s.T()) +} + +func (s *GormUtilsTestSuite) TestUpsertGet() { + // Write a long file + randomData := make([]byte, 90000) + _, err := rand.Read(randomData) + s.NoError(err) + + reader := bytes.NewBuffer(randomData) + tx := s.gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + los := LargeObjects{tx} + oid, err := los.Create() + s.Require().NoError(err) + err = los.Upsert(oid, reader) + s.Require().NoError(err) + s.Require().NoError(tx.Commit().Error) + + // Read it back and verify + tx = s.gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + los = LargeObjects{tx} + writer := bytes.NewBuffer([]byte{}) + s.Require().NoError(los.Get(oid, writer)) + s.Require().NoError(tx.Commit().Error) + + // Overwrite it + s.Require().Equal(randomData, writer.Bytes()) + reader = bytes.NewBuffer([]byte("hi")) + tx = s.gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + los = 
LargeObjects{tx} + err = los.Upsert(oid, reader) + s.Require().NoError(err) + s.Require().NoError(tx.Commit().Error) + + // Read it back and verify + tx = s.gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + los = LargeObjects{tx} + writer = bytes.NewBuffer([]byte{}) + writer.Reset() + s.Require().NoError(los.Get(oid, writer)) + s.Require().Equal([]byte("hi"), writer.Bytes()) + s.Require().NoError(tx.Commit().Error) +} + +func (s *GormUtilsTestSuite) TestLargeObjectSingleTransaction() { + tx := s.gormDB.Begin() + s.Require().NoError(tx.Error) + + los := &LargeObjects{tx} + + id, err := los.Create() + s.Require().NoError(err) + + obj, err := los.Open(id, ModeWrite|ModeRead) + s.Require().NoError(err) + + n, err := obj.Write([]byte("testing")) + s.Require().NoError(err) + s.Require().Equal(7, n, "Expected n to be 7, got %d", n) + + pos, err := obj.Seek(1, 0) + s.Require().NoError(err) + s.Require().Equal(int64(1), pos) + + res := make([]byte, 6) + n, err = obj.Read(res) + s.Require().NoError(err) + s.Require().Equal("esting", string(res)) + s.Require().Equal(6, n) + + n, err = obj.Read(res) + s.Require().Equal(err, io.EOF) + s.Require().Zero(n) + + pos, err = obj.Tell() + s.Require().NoError(err) + s.Require().EqualValues(7, pos) + + n, err = obj.Truncate(1) + s.Require().NoError(err) + + pos, err = obj.Seek(-1, 2) + s.Require().NoError(err) + s.Require().Zero(pos) + + res = make([]byte, 2) + n, err = obj.Read(res) + s.Require().Equal(io.EOF, err) + s.Require().Equal(1, n) + s.Require().EqualValues('t', res[0]) + + err = obj.Close() + s.Require().NoError(err) + + err = los.Unlink(id) + s.Require().NoError(err) + + _, err = los.Open(id, ModeRead) + s.Require().Contains(err.Error(), "does not exist (SQLSTATE 42704)") +} + +func (s *GormUtilsTestSuite) TestLargeObjectMultipleTransactions() { + tx := s.gormDB.Begin() + s.Require().NoError(tx.Error) + los := &LargeObjects{tx} + + id, err := los.Create() + s.Require().NoError(err) + obj, err := los.Open(id, 
ModeWrite|ModeRead) + s.Require().NoError(err) + + n, err := obj.Write([]byte("testing")) + s.Require().NoError(err) + s.Require().Equal(7, n, "Expected n to be 7, got %d", n) + + // Commit the first transaction + s.Require().NoError(tx.Commit().Error) + + // IMPORTANT: Use the same connection for another query + query := `select n from generate_series(1,10) n` + rows, err := s.gormDB.Raw(query).Rows() + s.Require().NoError(err) + rows.Close() + + // Start a new transaction + tx2 := s.gormDB.Begin() + // tx := gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + s.Require().NoError(tx.Error) + los2 := &LargeObjects{tx2} + + // Reopen the large object in the new transaction + obj2, err := los2.Open(id, ModeWrite|ModeRead) + s.Require().NoError(err) + + pos, err := obj2.Seek(1, 0) + s.Require().NoError(err) + s.Require().EqualValues(1, pos) + + res := make([]byte, 6) + n, err = obj2.Read(res) + s.Require().NoError(err) + s.Require().Equal("esting", string(res)) + s.Require().Equal(6, n) + + n, err = obj2.Read(res) + s.Require().Equal(err, io.EOF) + s.Require().Zero(n) + + pos, err = obj2.Tell() + s.Require().NoError(err) + s.Require().EqualValues(7, pos) + + n, err = obj2.Truncate(1) + s.Require().NoError(err) + + pos, err = obj2.Seek(-1, 2) + s.Require().NoError(err) + s.Require().Zero(pos) + + res = make([]byte, 2) + n, err = obj2.Read(res) + s.Require().Equal(io.EOF, err) + s.Require().Equal(1, n) + s.Require().EqualValues('t', res[0]) + + err = obj2.Close() + s.Require().NoError(err) + + err = los2.Unlink(id) + s.Require().NoError(err) + + _, err = los2.Open(id, ModeRead) + s.Require().Contains(err.Error(), "does not exist (SQLSTATE 42704)") +} From bea7011583010b84b04edc7f4812f3f3e3e7d087 Mon Sep 17 00:00:00 2001 From: cdu Date: Wed, 17 May 2023 10:47:38 -0700 Subject: [PATCH 32/40] merge and style --- .../m_180_to_m_181_move_to_blobstore/migration_test.go | 3 +++ .../m_180_to_m_181_move_to_blobstore/schema/gen.go | 3 +++ 
pkg/postgres/gorm/largeobject/large_objects.go | 6 +++++- pkg/postgres/gorm/largeobject/large_objects_test.go | 9 ++++++--- tools/roxvet/analyzers/validateimports/analyzer.go | 1 + 5 files changed, 18 insertions(+), 4 deletions(-) diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go index 020adfc2da288..7fa5c99e99bf8 100644 --- a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go @@ -1,3 +1,5 @@ +//go:build sql_integration + package m179tom180 import ( @@ -44,6 +46,7 @@ func (s *categoriesMigrationTestSuite) TestMigration() { reader := bytes.NewBuffer(randomData) file, err := os.CreateTemp("", "move-blob") + s.Require().NoError(err) defer func() { s.NoError(file.Close()) s.NoError(os.Remove(file.Name())) diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/gen.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/gen.go index 8265e535a968e..b1c32a96973f2 100644 --- a/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/gen.go +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/schema/gen.go @@ -1,3 +1,6 @@ package schema +// TODO(ROX-17180): Remove this auto-generation at the beginning of 4.2 or at least +// before we made schema change to Blob store after first release. 
+ //go:generate pg-schema-migration-helper --type=storage.Blob diff --git a/pkg/postgres/gorm/largeobject/large_objects.go b/pkg/postgres/gorm/largeobject/large_objects.go index 262d808ef02fe..7d7753465f5a2 100644 --- a/pkg/postgres/gorm/largeobject/large_objects.go +++ b/pkg/postgres/gorm/largeobject/large_objects.go @@ -15,11 +15,14 @@ type LargeObjects struct { *gorm.DB } +// Mode is the open mode for large object type Mode int32 const ( + // ModeWrite is bitmap for write operation on large object ModeWrite Mode = 0x20000 - ModeRead Mode = 0x40000 + // ModeRead is bitmap for read operation on large object + ModeRead Mode = 0x40000 ) // Create creates a new large object with an unused OID assigned @@ -74,6 +77,7 @@ func (o *LargeObjects) Upsert(oid uint32, r io.Reader) error { return err } +// Get gets the content of the large object and write it to the writer. func (o *LargeObjects) Get(oid uint32, w io.Writer) error { obj, err := o.Open(oid, ModeRead) if err != nil { diff --git a/pkg/postgres/gorm/largeobject/large_objects_test.go b/pkg/postgres/gorm/largeobject/large_objects_test.go index b042f280e50c1..d1508994a11f8 100644 --- a/pkg/postgres/gorm/largeobject/large_objects_test.go +++ b/pkg/postgres/gorm/largeobject/large_objects_test.go @@ -1,3 +1,5 @@ +//go:build sql_integration + package largeobject import ( @@ -110,7 +112,7 @@ func (s *GormUtilsTestSuite) TestLargeObjectSingleTransaction() { s.Require().NoError(err) s.Require().EqualValues(7, pos) - n, err = obj.Truncate(1) + _, err = obj.Truncate(1) s.Require().NoError(err) pos, err = obj.Seek(-1, 2) @@ -154,7 +156,8 @@ func (s *GormUtilsTestSuite) TestLargeObjectMultipleTransactions() { query := `select n from generate_series(1,10) n` rows, err := s.gormDB.Raw(query).Rows() s.Require().NoError(err) - rows.Close() + s.Require().NoError(rows.Err()) + s.NoError(rows.Close()) // Start a new transaction tx2 := s.gormDB.Begin() @@ -184,7 +187,7 @@ func (s *GormUtilsTestSuite) 
TestLargeObjectMultipleTransactions() { s.Require().NoError(err) s.Require().EqualValues(7, pos) - n, err = obj2.Truncate(1) + _, err = obj2.Truncate(1) s.Require().NoError(err) pos, err = obj2.Seek(-1, 2) diff --git a/tools/roxvet/analyzers/validateimports/analyzer.go b/tools/roxvet/analyzers/validateimports/analyzer.go index f2c145b10c135..751d0c56f5e65 100644 --- a/tools/roxvet/analyzers/validateimports/analyzer.go +++ b/tools/roxvet/analyzers/validateimports/analyzer.go @@ -256,6 +256,7 @@ func verifyImportsFromAllowedPackagesOnly(pass *analysis.Pass, imports []*ast.Im "pkg/migrations", "pkg/nodes/converter", "pkg/policyutils", + "pkg/postgres/gorm", "pkg/postgres/pgadmin", "pkg/postgres/pgconfig", "pkg/postgres/pgtest", From d37f3715f0cead1881a2eb07e495dcf3541fa12d Mon Sep 17 00:00:00 2001 From: cdu Date: Wed, 17 May 2023 10:49:16 -0700 Subject: [PATCH 33/40] small review first --- central/scannerdefinitions/handler/handler_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/central/scannerdefinitions/handler/handler_test.go b/central/scannerdefinitions/handler/handler_test.go index d83eec49a4359..30f8e8fe5c9a7 100644 --- a/central/scannerdefinitions/handler/handler_test.go +++ b/central/scannerdefinitions/handler/handler_test.go @@ -1,3 +1,5 @@ +//go:build sql_integration + package handler import ( From 5a803d0781a33740191609e30967f9ac4e034e84 Mon Sep 17 00:00:00 2001 From: cdu Date: Wed, 17 May 2023 13:46:25 -0700 Subject: [PATCH 34/40] Resolve review comments --- central/scannerdefinitions/handler/handler.go | 6 +++--- central/scannerdefinitions/handler/handler_test.go | 6 +----- central/scannerdefinitions/handler/options.go | 1 + central/scannerdefinitions/handler/updater_test.go | 4 ++++ 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/central/scannerdefinitions/handler/handler.go b/central/scannerdefinitions/handler/handler.go index 77c1a47012bd8..59af1ded1b511 100644 --- a/central/scannerdefinitions/handler/handler.go +++ 
b/central/scannerdefinitions/handler/handler.go @@ -353,12 +353,12 @@ func (h *httpHandler) openMostRecentDefinitions(ctx context.Context, uuid string // Open both the "online" and "offline", and save their modification times. var onlineFile *vulDefFile - onlineOsFile, onlineTime, err := u.file.Open() + onlineOSFile, onlineTime, err := u.file.Open() if err != nil { return } - if onlineOsFile != nil { - onlineFile = &vulDefFile{File: onlineOsFile, modTime: onlineTime} + if onlineOSFile != nil { + onlineFile = &vulDefFile{File: onlineOSFile, modTime: onlineTime} } defer toClose(onlineFile) diff --git a/central/scannerdefinitions/handler/handler_test.go b/central/scannerdefinitions/handler/handler_test.go index 30f8e8fe5c9a7..f220456f07803 100644 --- a/central/scannerdefinitions/handler/handler_test.go +++ b/central/scannerdefinitions/handler/handler_test.go @@ -170,13 +170,9 @@ func (s *handlerTestSuite) TestServeHTTP_Online_Get() { assert.Empty(t, w.Data.String()) } -func mustSetModTime(t *testing.T, path string, modTime time.Time) { - require.NoError(t, os.Chtimes(path, time.Now(), modTime)) -} - func (s *handlerTestSuite) mustWriteOffline(content string, modTime time.Time) { modifiedTime, err := types.TimestampProto(modTime) - s.NoError(err) + s.Require().NoError(err) blob := &storage.Blob{ Name: offlineScannerDefinitionBlobName, ModifiedTime: modifiedTime, diff --git a/central/scannerdefinitions/handler/options.go b/central/scannerdefinitions/handler/options.go index 148814833e101..32ff7569ade69 100644 --- a/central/scannerdefinitions/handler/options.go +++ b/central/scannerdefinitions/handler/options.go @@ -7,6 +7,7 @@ type handlerOpts struct { // The following are options which are only respected in online-mode. // cleanupInterval sets the interval for cleaning up updaters. cleanupInterval *time.Duration + // cleanupAge sets the age after which an updater should be cleaned. 
cleanupAge *time.Duration } diff --git a/central/scannerdefinitions/handler/updater_test.go b/central/scannerdefinitions/handler/updater_test.go index d7cb3647c89aa..549af0c586813 100644 --- a/central/scannerdefinitions/handler/updater_test.go +++ b/central/scannerdefinitions/handler/updater_test.go @@ -54,3 +54,7 @@ func mustGetModTime(t *testing.T, path string) time.Time { require.NoError(t, err) return fi.ModTime().UTC() } + +func mustSetModTime(t *testing.T, path string, modTime time.Time) { + require.NoError(t, os.Chtimes(path, time.Now(), modTime)) +} From f7190ff704f88dc189ab9b8ec576c6a9215eb9bd Mon Sep 17 00:00:00 2001 From: cdu Date: Wed, 17 May 2023 13:52:38 -0700 Subject: [PATCH 35/40] add migration --- migrator/runner/all.go | 1 + 1 file changed, 1 insertion(+) diff --git a/migrator/runner/all.go b/migrator/runner/all.go index fad02c8cfafbb..33872b711b938 100644 --- a/migrator/runner/all.go +++ b/migrator/runner/all.go @@ -136,4 +136,5 @@ import ( _ "github.com/stackrox/rox/migrator/migrations/m_177_to_m_178_group_permissions" _ "github.com/stackrox/rox/migrator/migrations/m_178_to_m_179_embedded_collections_search_label" _ "github.com/stackrox/rox/migrator/migrations/m_179_to_m_180_openshift_policy_exclusions" + _ "github.com/stackrox/rox/migrator/migrations/m_180_to_m_181_move_to_blobstore" ) From 09b16a5670e060129ad479004902b6822e515c7a Mon Sep 17 00:00:00 2001 From: cdu Date: Thu, 18 May 2023 09:47:41 -0700 Subject: [PATCH 36/40] review comments --- .../migrations/m_180_to_m_181_move_to_blobstore/migration.go | 2 +- .../m_180_to_m_181_move_to_blobstore/migration_test.go | 2 +- pkg/migrations/internal/seq_num.go | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go index 6fcf38125c5ab..a6aa54d2c11a7 100644 --- a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go +++ 
b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go @@ -1,4 +1,4 @@ -package m179tom180 +package m180tom181 import ( "context" diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go index 7fa5c99e99bf8..3330938ffb120 100644 --- a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go @@ -1,6 +1,6 @@ //go:build sql_integration -package m179tom180 +package m180tom181 import ( "bytes" diff --git a/pkg/migrations/internal/seq_num.go b/pkg/migrations/internal/seq_num.go index c0c65febd0877..063a38da4d67f 100644 --- a/pkg/migrations/internal/seq_num.go +++ b/pkg/migrations/internal/seq_num.go @@ -4,13 +4,13 @@ var ( // CurrentDBVersionSeqNum is the current DB version number. // This must be incremented every time we write a migration. // It is a shared constant between central and the migrator binary. - CurrentDBVersionSeqNum = 180 + CurrentDBVersionSeqNum = 181 // MinimumSupportedDBVersionSeqNum is the minimum DB version number // that is supported by this database. This is used in case of rollbacks in // the event that a major change introduced an incompatible schema update we // can inform that a rollback below this is not supported by the database - MinimumSupportedDBVersionSeqNum = 180 + MinimumSupportedDBVersionSeqNum = 181 // LastRocksDBVersionSeqNum is the sequence number for the last RocksDB version. 
LastRocksDBVersionSeqNum = 112 From c0346c0e12387cbedc5c995feccdfc9287d86cd8 Mon Sep 17 00:00:00 2001 From: cdu Date: Thu, 18 May 2023 10:42:35 -0700 Subject: [PATCH 37/40] close --- .../migrations/m_180_to_m_181_move_to_blobstore/migration.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go index a6aa54d2c11a7..9a257282bb78f 100644 --- a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go @@ -15,6 +15,7 @@ import ( "github.com/stackrox/rox/pkg/postgres/gorm/largeobject" "github.com/stackrox/rox/pkg/postgres/pgutils" "github.com/stackrox/rox/pkg/sac" + "github.com/stackrox/rox/pkg/utils" "gorm.io/gorm" ) @@ -74,6 +75,7 @@ func moveScannerDefination(tx *gorm.DB) error { if os.IsNotExist(err) { return nil } + defer utils.IgnoreError(fd.Close) if err != nil { return errors.Wrapf(err, "failed to open %s", scannerDefPath) } From 60e8a15a9a78674853df6b635265b246d50f4a51 Mon Sep 17 00:00:00 2001 From: cdu Date: Thu, 18 May 2023 11:12:31 -0700 Subject: [PATCH 38/40] review --- .../m_180_to_m_181_move_to_blobstore/migration_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go index 3330938ffb120..7b0b90f333572 100644 --- a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go @@ -56,6 +56,8 @@ func (s *categoriesMigrationTestSuite) TestMigration() { s.Require().NoError(err) s.Require().EqualValues(size, n) + fileInfo, err := file.Stat() + s.Require().NoError(err) // Migrate s.Require().NoError(moveToBlobs(s.db.GetGormDB())) @@ -69,8 +71,6 @@ func (s *categoriesMigrationTestSuite) TestMigration() { 
s.Equal(scannerDefBlobName, blob.GetName()) s.EqualValues(size, blob.GetLength()) - fileInfo, err := file.Stat() - s.Require().NoError(err) modTime := pgutils.NilOrTime(blob.GetModifiedTime()) s.Equal(fileInfo.ModTime().UTC(), modTime.UTC()) From 0799e65ccca0ce6f090ce0ec841debf9383fa6ac Mon Sep 17 00:00:00 2001 From: cdu Date: Thu, 18 May 2023 14:17:30 -0700 Subject: [PATCH 39/40] size and minor change --- central/blob/datastore/store/store.go | 12 ++++++++++++ central/blob/datastore/store/store_test.go | 3 ++- central/blob/snapshot/snapshot_test.go | 2 +- central/scannerdefinitions/handler/handler_test.go | 2 ++ 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/central/blob/datastore/store/store.go b/central/blob/datastore/store/store.go index d8fed9f6ff0b9..d0aa1ce0bebac 100644 --- a/central/blob/datastore/store/store.go +++ b/central/blob/datastore/store/store.go @@ -77,8 +77,11 @@ func (s *storeImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io.Rea obj.Oid = oid } buf := make([]byte, 1024*1024) + + var totalRead int64 for { nRead, err := reader.Read(buf) + totalRead += int64(nRead) if nRead != 0 { if _, err := lo.Write(buf[:nRead]); err != nil { @@ -97,6 +100,9 @@ func (s *storeImpl) Upsert(ctx context.Context, obj *storage.Blob, reader io.Rea if err := lo.Close(); err != nil { return wrapRollback(ctx, tx, errors.Wrap(err, "closing large object for blob")) } + if totalRead != obj.GetLength() { + return wrapRollback(ctx, tx, errors.Errorf("Blob metadata mismatch. 
Blob metadata shows %d in length, but data has length of %d", obj.GetLength(), totalRead)) + } if err := s.store.Upsert(ctx, obj); err != nil { return wrapRollback(ctx, tx, errors.Wrapf(err, "error upserting blob %q", obj.GetName())) @@ -125,8 +131,10 @@ func (s *storeImpl) Get(ctx context.Context, name string, writer io.Writer) (*st } buf := make([]byte, 1024*1024) + var totalRead int64 for { nRead, err := lo.Read(buf) + totalRead += int64(nRead) // nRead can be non-zero when err == io.EOF if nRead != 0 { @@ -147,6 +155,10 @@ func (s *storeImpl) Get(ctx context.Context, name string, writer io.Writer) (*st return nil, false, wrapRollback(ctx, tx, err) } + if totalRead != existingBlob.GetLength() { + return nil, false, wrapRollback(ctx, tx, errors.Errorf("Blob %s corrupted. Blob metadata shows %d in length, but data has length of %d", existingBlob.GetName(), existingBlob.GetLength(), totalRead)) + } + return existingBlob, true, tx.Commit(ctx) } diff --git a/central/blob/datastore/store/store_test.go b/central/blob/datastore/store/store_test.go index 5ecfa14602b9c..7fbc27b4e64ba 100644 --- a/central/blob/datastore/store/store_test.go +++ b/central/blob/datastore/store/store_test.go @@ -43,9 +43,11 @@ func (s *BlobsStoreSuite) TearDownSuite() { func (s *BlobsStoreSuite) TestStore() { ctx := sac.WithAllAccess(context.Background()) + size := 1024*1024 + 16 insertBlob := &storage.Blob{ Name: "test", + Length: int64(size), LastUpdated: timestamp.TimestampNow(), ModifiedTime: timestamp.TimestampNow(), } @@ -55,7 +57,6 @@ func (s *BlobsStoreSuite) TestStore() { s.Require().NoError(err) s.Require().False(exists) - size := 1024*1024 + 16 randomData := make([]byte, size) _, err = rand.Read(randomData) s.NoError(err) diff --git a/central/blob/snapshot/snapshot_test.go b/central/blob/snapshot/snapshot_test.go index e57277248404b..75141cc0ef747 100644 --- a/central/blob/snapshot/snapshot_test.go +++ b/central/blob/snapshot/snapshot_test.go @@ -26,7 +26,7 @@ type snapshotTestSuite 
struct { testDB *pgtest.TestPostgres } -func TestBlobsStore(t *testing.T) { +func TestBlobsStoreSnapshot(t *testing.T) { suite.Run(t, new(snapshotTestSuite)) } diff --git a/central/scannerdefinitions/handler/handler_test.go b/central/scannerdefinitions/handler/handler_test.go index f220456f07803..ee523f0965998 100644 --- a/central/scannerdefinitions/handler/handler_test.go +++ b/central/scannerdefinitions/handler/handler_test.go @@ -175,7 +175,9 @@ func (s *handlerTestSuite) mustWriteOffline(content string, modTime time.Time) { s.Require().NoError(err) blob := &storage.Blob{ Name: offlineScannerDefinitionBlobName, + Length: int64(len(content)), ModifiedTime: modifiedTime, + LastUpdated: types.TimestampNow(), } s.Require().NoError(s.datastore.Upsert(s.ctx, blob, bytes.NewBuffer([]byte(content)))) } From 2bb5551940dcef860eedfdf050674169fe9b505b Mon Sep 17 00:00:00 2001 From: cdu Date: Fri, 19 May 2023 10:25:33 -0700 Subject: [PATCH 40/40] Resolve review comment --- .../migration.go | 31 +++++++++---------- .../migration_test.go | 10 +++--- .../gorm/largeobject/large_objects.go | 20 +++--------- .../gorm/largeobject/large_objects_test.go | 1 - 4 files changed, 25 insertions(+), 37 deletions(-) diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go index 9a257282bb78f..3083ef6e2aa13 100644 --- a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration.go @@ -48,42 +48,41 @@ func moveToBlobs(db *gorm.DB) (err error) { pgutils.CreateTableFromModel(context.Background(), db, schema.CreateTableBlobsStmt) tx := db.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - if err = moveScannerDefination(tx); err != nil { + if err = moveScannerDefinitions(tx); err != nil { result := tx.Rollback() if result.Error != nil { - return result.Error + log.Warnf("failed to rollback with error %v", result.Error) } - 
return err + return errors.Wrap(err, "failed to move scanner definition to blob store.") } return tx.Commit().Error } -func moveScannerDefination(tx *gorm.DB) error { - stat, err := os.Stat(scannerDefPath) +func moveScannerDefinitions(tx *gorm.DB) error { + fd, err := os.Open(scannerDefPath) + if os.IsNotExist(err) { + return nil + } if err != nil { - if os.IsNotExist(err) || stat.IsDir() { - return nil - } - return err + return errors.Wrapf(err, "failed to open %s", scannerDefPath) } - modTime, err := timestamp.TimestampProto(stat.ModTime()) + defer utils.IgnoreError(fd.Close) + stat, err := fd.Stat() if err != nil { - return errors.Wrapf(err, "invalid timestamp %v", stat.ModTime()) + return err } - fd, err := os.Open(scannerDefPath) - if os.IsNotExist(err) { + if stat.IsDir() { return nil } - defer utils.IgnoreError(fd.Close) + modTime, err := timestamp.TimestampProto(stat.ModTime()) if err != nil { - return errors.Wrapf(err, "failed to open %s", scannerDefPath) + return errors.Wrapf(err, "invalid timestamp %v", stat.ModTime()) } // Prepare blob blob := &storage.Blob{ Name: scannerDefBlobName, - Oid: 0, Length: stat.Size(), LastUpdated: timestamp.TimestampNow(), ModifiedTime: modTime, diff --git a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go index 7b0b90f333572..0f6b8284a63cf 100644 --- a/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go +++ b/migrator/migrations/m_180_to_m_181_move_to_blobstore/migration_test.go @@ -16,25 +16,25 @@ import ( "github.com/stretchr/testify/suite" ) -type categoriesMigrationTestSuite struct { +type blobMigrationTestSuite struct { suite.Suite db *pghelper.TestPostgres } func TestMigration(t *testing.T) { - suite.Run(t, new(categoriesMigrationTestSuite)) + suite.Run(t, new(blobMigrationTestSuite)) } -func (s *categoriesMigrationTestSuite) SetupTest() { +func (s *blobMigrationTestSuite) SetupTest() { s.db = 
pghelper.ForT(s.T(), true) } -func (s *categoriesMigrationTestSuite) TearDownTest() { +func (s *blobMigrationTestSuite) TearDownTest() { s.db.Teardown(s.T()) } -func (s *categoriesMigrationTestSuite) TestMigration() { +func (s *blobMigrationTestSuite) TestMigration() { // Nothing to migrate s.Require().NoError(moveToBlobs(s.db.GetGormDB())) diff --git a/pkg/postgres/gorm/largeobject/large_objects.go b/pkg/postgres/gorm/largeobject/large_objects.go index 7d7753465f5a2..de6deb16db1de 100644 --- a/pkg/postgres/gorm/largeobject/large_objects.go +++ b/pkg/postgres/gorm/largeobject/large_objects.go @@ -7,7 +7,7 @@ import ( "gorm.io/gorm" ) -// LargeObjects is used to access the large objects API with gorm CRM. +// LargeObjects is used to access the large objects API with gorm ORM. // // This is originally created with similar API with existing github.com/jackc/pgx // For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html @@ -62,19 +62,12 @@ func (o *LargeObjects) Upsert(oid uint32, r io.Reader) error { if err != nil { return err } - defer func() { - err = obj.Close() - }() - _, err = obj.Truncate(0) - if err != nil { - return err + if _, err = obj.Truncate(0); err != nil { + return errors.Join(err, obj.Close()) } _, err = io.Copy(obj, r) - if err != nil { - return err - } - return err + return errors.Join(err, obj.Close()) } // Get gets the content of the large object and write it to the writer. @@ -162,8 +155,5 @@ func (o *LargeObject) Close() error { // wrapClose closes the large object and returns error if failed. 
Otherwise, it // returns err func (o *LargeObject) wrapClose(err error) error { - if closeErr := o.Close(); closeErr != nil { - return closeErr - } - return err + return errors.Join(err, o.Close()) } diff --git a/pkg/postgres/gorm/largeobject/large_objects_test.go b/pkg/postgres/gorm/largeobject/large_objects_test.go index d1508994a11f8..ae7b9e9f5dbd1 100644 --- a/pkg/postgres/gorm/largeobject/large_objects_test.go +++ b/pkg/postgres/gorm/largeobject/large_objects_test.go @@ -161,7 +161,6 @@ func (s *GormUtilsTestSuite) TestLargeObjectMultipleTransactions() { // Start a new transaction tx2 := s.gormDB.Begin() - // tx := gormDB.Begin(&sql.TxOptions{Isolation: sql.LevelRepeatableRead}) s.Require().NoError(tx.Error) los2 := &LargeObjects{tx2}