diff --git a/models/db/context.go b/models/db/context.go index 43f612518a..171e26b933 100644 --- a/models/db/context.go +++ b/models/db/context.go @@ -6,6 +6,12 @@ package db import ( "context" "database/sql" + "errors" + "runtime" + "slices" + "sync" + + "code.gitea.io/gitea/modules/setting" "xorm.io/builder" "xorm.io/xorm" @@ -15,45 +21,23 @@ import ( // will be overwritten by Init with HammerContext var DefaultContext context.Context -// contextKey is a value for use with context.WithValue. -type contextKey struct { - name string -} +type engineContextKeyType struct{} -// enginedContextKey is a context key. It is used with context.Value() to get the current Engined for the context -var ( - enginedContextKey = &contextKey{"engined"} - _ Engined = &Context{} -) +var engineContextKey = engineContextKeyType{} // Context represents a db context type Context struct { context.Context - e Engine - transaction bool + engine Engine } -func newContext(ctx context.Context, e Engine, transaction bool) *Context { - return &Context{ - Context: ctx, - e: e, - transaction: transaction, - } -} - -// InTransaction if context is in a transaction -func (ctx *Context) InTransaction() bool { - return ctx.transaction -} - -// Engine returns db engine -func (ctx *Context) Engine() Engine { - return ctx.e +func newContext(ctx context.Context, e Engine) *Context { + return &Context{Context: ctx, engine: e} } // Value shadows Value for context.Context but allows us to get ourselves and an Engined object func (ctx *Context) Value(key any) any { - if key == enginedContextKey { + if key == engineContextKey { return ctx } return ctx.Context.Value(key) @@ -61,30 +45,66 @@ func (ctx *Context) Value(key any) any { // WithContext returns this engine tied to this context func (ctx *Context) WithContext(other context.Context) *Context { - return newContext(ctx, ctx.e.Context(other), ctx.transaction) + return newContext(ctx, ctx.engine.Context(other)) } -// Engined structs provide an Engine -type Engined interface { - Engine() Engine +var ( + contextSafetyOnce sync.Once + contextSafetyDeniedFuncPCs []uintptr +) + +func contextSafetyCheck(e Engine) { + if setting.IsProd && !setting.IsInTesting { + return + } + if e == nil { + return + } + // Only do this check for non-end-users. If the problem could be fixed in the future, this code could be removed. + contextSafetyOnce.Do(func() { + // try to figure out the bad functions to deny + type m struct{} + _ = e.SQL("SELECT 1").Iterate(&m{}, func(int, any) error { + callers := make([]uintptr, 32) + callerNum := runtime.Callers(1, callers) + for i := 0; i < callerNum; i++ { + if funcName := runtime.FuncForPC(callers[i]).Name(); funcName == "xorm.io/xorm.(*Session).Iterate" { + contextSafetyDeniedFuncPCs = append(contextSafetyDeniedFuncPCs, callers[i]) + } + } + return nil + }) + if len(contextSafetyDeniedFuncPCs) != 1 { + panic(errors.New("unable to determine the functions to deny")) + } + }) + + // it should be very fast: xxxx ns/op + callers := make([]uintptr, 32) + callerNum := runtime.Callers(3, callers) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine + for i := 0; i < callerNum; i++ { + if slices.Contains(contextSafetyDeniedFuncPCs, callers[i]) { + panic(errors.New("using database context in an iterator would cause corrupted results")) + } + } } -// GetEngine will get a db Engine from this context or return an Engine restricted to this context +// GetEngine gets an existing db Engine/Statement or creates a new Session func GetEngine(ctx context.Context) Engine { - if e := getEngine(ctx); e != nil { + if e := getExistingEngine(ctx); e != nil { return e } return x.Context(ctx) } -// getEngine will get a db Engine from this context or return nil -func getEngine(ctx context.Context) Engine { - if engined, ok := ctx.(Engined); ok { - return engined.Engine() +// getExistingEngine gets an existing db Engine/Statement from this context or returns nil +func getExistingEngine(ctx context.Context) (e Engine) { + defer func() { contextSafetyCheck(e) }() + if engined, ok := ctx.(*Context); ok { + return engined.engine } - enginedInterface := ctx.Value(enginedContextKey) - if enginedInterface != nil { - return enginedInterface.(Engined).Engine() + if engined, ok := ctx.Value(engineContextKey).(*Context); ok { + return engined.engine } return nil } @@ -132,23 +152,23 @@ func (c *halfCommitter) Close() error { // d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback. func TxContext(parentCtx context.Context) (*Context, Committer, error) { if sess, ok := inTransaction(parentCtx); ok { - return newContext(parentCtx, sess, true), &halfCommitter{committer: sess}, nil + return newContext(parentCtx, sess), &halfCommitter{committer: sess}, nil } sess := x.NewSession() if err := sess.Begin(); err != nil { - sess.Close() + _ = sess.Close() return nil, nil, err } - return newContext(DefaultContext, sess, true), sess, nil + return newContext(DefaultContext, sess), sess, nil } // WithTx represents executing database operations on a transaction, if the transaction exist, // this function will reuse it otherwise will create a new one and close it when finished. func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error { if sess, ok := inTransaction(parentCtx); ok { - err := f(newContext(parentCtx, sess, true)) + err := f(newContext(parentCtx, sess)) if err != nil { // rollback immediately, in case the caller ignores returned error and tries to commit the transaction. _ = sess.Close() @@ -165,7 +185,7 @@ func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error) return err } - if err := f(newContext(parentCtx, sess, true)); err != nil { + if err := f(newContext(parentCtx, sess)); err != nil { return err } @@ -312,7 +332,7 @@ func InTransaction(ctx context.Context) bool { } func inTransaction(ctx context.Context) (*xorm.Session, bool) { - e := getEngine(ctx) + e := getExistingEngine(ctx) if e == nil { return nil, false } diff --git a/models/db/context_test.go b/models/db/context_test.go index 95a01d4a26..e8c6b74d93 100644 --- a/models/db/context_test.go +++ b/models/db/context_test.go @@ -84,3 +84,47 @@ func TestTxContext(t *testing.T) { })) } } + +func TestContextSafety(t *testing.T) { + type TestModel1 struct { + ID int64 + } + type TestModel2 struct { + ID int64 + } + assert.NoError(t, unittest.GetXORMEngine().Sync(&TestModel1{}, &TestModel2{})) + assert.NoError(t, db.TruncateBeans(db.DefaultContext, &TestModel1{}, &TestModel2{})) + testCount := 10 + for i := 1; i <= testCount; i++ { + assert.NoError(t, db.Insert(db.DefaultContext, &TestModel1{ID: int64(i)})) + assert.NoError(t, db.Insert(db.DefaultContext, &TestModel2{ID: int64(-i)})) + } + + actualCount := 0 + // here: db.GetEngine(db.DefaultContext) is a new *Session created from *Engine + _ = db.WithTx(db.DefaultContext, func(ctx context.Context) error { + _ = db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error { + // here: db.GetEngine(ctx) is always the unclosed "Iterate" *Session with autoResetStatement=false, + // and the internal states (including "cond" and others) are always there and not be reset in this callback. + m1 := bean.(*TestModel1) + assert.EqualValues(t, i+1, m1.ID) + + // here: XORM bug, it fails because the SQL becomes "WHERE id=-1", "WHERE id=-1 AND id=-2", "WHERE id=-1 AND id=-2 AND id=-3" ... + // and it conflicts with the "Iterate"'s internal states. + // has, err := db.GetEngine(ctx).Get(&TestModel2{ID: -m1.ID}) + + actualCount++ + return nil + }) + return nil + }) + assert.EqualValues(t, testCount, actualCount) + + // deny the bad usages + assert.PanicsWithError(t, "using database context in an iterator would cause corrupted results", func() { + _ = unittest.GetXORMEngine().Iterate(&TestModel1{}, func(i int, bean any) error { + _ = db.GetEngine(db.DefaultContext) + return nil + }) + }) +} diff --git a/models/db/engine.go b/models/db/engine.go index 847ba58c26..e50a8580bf 100755 --- a/models/db/engine.go +++ b/models/db/engine.go @@ -161,10 +161,7 @@ func InitEngine(ctx context.Context) error { // SetDefaultEngine sets the default engine for db func SetDefaultEngine(ctx context.Context, eng *xorm.Engine) { x = eng - DefaultContext = &Context{ - Context: ctx, - e: x, - } + DefaultContext = &Context{Context: ctx, engine: x} } // UnsetDefaultEngine closes and unsets the default engine diff --git a/models/db/install/db.go b/models/db/install/db.go index d4c1139637..1b3b2ec3e9 100644 --- a/models/db/install/db.go +++ b/models/db/install/db.go @@ -11,7 +11,7 @@ import ( ) func getXORMEngine() *xorm.Engine { - return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine) + return db.GetEngine(db.DefaultContext).(*xorm.Engine) } // CheckDatabaseConnection checks the database connection diff --git a/models/db/iterate.go b/models/db/iterate.go index e1caefa72b..481be1b4b7 100644 --- a/models/db/iterate.go +++ b/models/db/iterate.go @@ -11,7 +11,7 @@ import ( "xorm.io/builder" ) -// Iterate iterate all the Bean object +// Iterate iterates all the Bean object func Iterate[Bean any](ctx context.Context, cond builder.Cond, f func(ctx context.Context, bean *Bean) error) error { var start int batchSize := setting.Database.IterateBufferSize diff --git a/models/packages/debian/search.go b/models/packages/debian/search.go index 77c4a18462..5333d0c6e4 100644 --- a/models/packages/debian/search.go +++ b/models/packages/debian/search.go @@ -75,26 +75,27 @@ func ExistPackages(ctx context.Context, opts *PackageSearchOptions) (bool, error } // SearchPackages gets the packages matching the search options -func SearchPackages(ctx context.Context, opts *PackageSearchOptions, iter func(*packages.PackageFileDescriptor)) error { - return db.GetEngine(ctx). +func SearchPackages(ctx context.Context, opts *PackageSearchOptions) ([]*packages.PackageFileDescriptor, error) { + var pkgFiles []*packages.PackageFile + err := db.GetEngine(ctx). Table("package_file"). Select("package_file.*"). Join("INNER", "package_version", "package_version.id = package_file.version_id"). Join("INNER", "package", "package.id = package_version.package_id"). Where(opts.toCond()). - Asc("package.lower_name", "package_version.created_unix"). - Iterate(new(packages.PackageFile), func(_ int, bean any) error { - pf := bean.(*packages.PackageFile) - - pfd, err := packages.GetPackageFileDescriptor(ctx, pf) - if err != nil { - return err - } - - iter(pfd) - - return nil - }) + Asc("package.lower_name", "package_version.created_unix").Find(&pkgFiles) + if err != nil { + return nil, err + } + pfds := make([]*packages.PackageFileDescriptor, 0, len(pkgFiles)) + for _, pf := range pkgFiles { + pfd, err := packages.GetPackageFileDescriptor(ctx, pf) + if err != nil { + return nil, err + } + pfds = append(pfds, pfd) + } + return pfds, nil } // GetDistributions gets all available distributions diff --git a/models/unittest/fixtures.go b/models/unittest/fixtures.go index c653ce1e38..4dde5410d6 100644 --- a/models/unittest/fixtures.go +++ b/models/unittest/fixtures.go @@ -25,7 +25,7 @@ func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) { if len(engine) == 1 { return engine[0] } - return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine) + return db.GetEngine(db.DefaultContext).(*xorm.Engine) } // InitFixtures initialize test fixtures for a test database diff --git a/services/packages/cleanup/cleanup.go b/services/packages/cleanup/cleanup.go index 5d5120c6a0..d7c9355da5 100644 --- a/services/packages/cleanup/cleanup.go +++ b/services/packages/cleanup/cleanup.go @@ -22,7 +22,7 @@ import ( rpm_service "code.gitea.io/gitea/services/packages/rpm" ) -// Task method to execute cleanup rules and cleanup expired package data +// CleanupTask executes cleanup rules and cleanup expired package data func CleanupTask(ctx context.Context, olderThan time.Duration) error { if err := ExecuteCleanupRules(ctx); err != nil { return err diff --git a/services/packages/debian/repository.go b/services/packages/debian/repository.go index 611faa6ade..13e98a820e 100644 --- a/services/packages/debian/repository.go +++ b/services/packages/debian/repository.go @@ -206,7 +206,11 @@ func buildPackagesIndices(ctx context.Context, ownerID int64, repoVersion *packa w := io.MultiWriter(packagesContent, gzw, xzw) addSeparator := false - if err := debian_model.SearchPackages(ctx, opts, func(pfd *packages_model.PackageFileDescriptor) { + pfds, err := debian_model.SearchPackages(ctx, opts) + if err != nil { + return err + } + for _, pfd := range pfds { if addSeparator { fmt.Fprintln(w) } @@ -220,10 +224,7 @@ func buildPackagesIndices(ctx context.Context, ownerID int64, repoVersion *packa fmt.Fprintf(w, "SHA1: %s\n", pfd.Blob.HashSHA1) fmt.Fprintf(w, "SHA256: %s\n", pfd.Blob.HashSHA256) fmt.Fprintf(w, "SHA512: %s\n", pfd.Blob.HashSHA512) - }); err != nil { - return err } - gzw.Close() xzw.Close() diff --git a/tests/integration/api_packages_debian_test.go b/tests/integration/api_packages_debian_test.go index 05979fccb5..98027d774c 100644 --- a/tests/integration/api_packages_debian_test.go +++ b/tests/integration/api_packages_debian_test.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" "testing" @@ -19,6 +20,7 @@ import ( user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/base" debian_module "code.gitea.io/gitea/modules/packages/debian" + packages_cleanup_service "code.gitea.io/gitea/services/packages/cleanup" "code.gitea.io/gitea/tests" "github.com/blakesmith/ar" @@ -263,4 +265,37 @@ func TestPackageDebian(t *testing.T) { assert.Contains(t, body, "Components: "+strings.Join(components, " ")+"\n") assert.Contains(t, body, "Architectures: "+architectures[1]+"\n") }) + + t.Run("Cleanup", func(t *testing.T) { + defer tests.PrintCurrentTest(t)() + + rule := &packages.PackageCleanupRule{ + Enabled: true, + RemovePattern: `.*`, + MatchFullName: true, + OwnerID: user.ID, + Type: packages.TypeDebian, + } + + _, err := packages.InsertCleanupRule(db.DefaultContext, rule) + assert.NoError(t, err) + + // When there were a lot of packages (> 50 or 100) and the code used "Iterate" to get all packages, it ever caused bugs, + // because "Iterate" keeps a dangling SQL session but the callback function still uses the same session to execute statements. + // The "Iterate" problem has been checked by TestContextSafety now, so here we only need to check the cleanup logic with a small number + packagesCount := 2 + for i := 0; i < packagesCount; i++ { + uploadURL := fmt.Sprintf("%s/pool/%s/%s/upload", rootURL, "test", "main") + req := NewRequestWithBody(t, "PUT", uploadURL, createArchive(packageName, "1.0."+strconv.Itoa(i), "all")).AddBasicAuth(user.Name) + MakeRequest(t, req, http.StatusCreated) + } + req := NewRequest(t, "GET", fmt.Sprintf("%s/dists/%s/Release", rootURL, "test")) + MakeRequest(t, req, http.StatusOK) + + err = packages_cleanup_service.CleanupTask(db.DefaultContext, 0) + assert.NoError(t, err) + + req = NewRequest(t, "GET", fmt.Sprintf("%s/dists/%s/Release", rootURL, "test")) + MakeRequest(t, req, http.StatusNotFound) + }) }