1
1
mirror of https://github.com/go-gitea/gitea synced 2025-07-23 02:38:35 +00:00

Move db related basic functions to models/db (#17075)

* Move db related basic functions to models/db

* Fix lint

* Fix lint

* Fix test

* Fix lint

* Fix lint

* revert unnecessary change

* Fix test

* Fix wrong replace string

* Use *Context

* Correct committer spelling and fix wrong replaced words

Co-authored-by: zeripath <art27@cantab.net>
This commit is contained in:
Lunny Xiao
2021-09-19 19:49:59 +08:00
committed by GitHub
parent 462306e263
commit a4bfef265d
335 changed files with 4191 additions and 3654 deletions

86
models/db/context.go Normal file
View File

@@ -0,0 +1,86 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"code.gitea.io/gitea/modules/setting"
"xorm.io/builder"
"xorm.io/xorm"
)
// Context represents a db context
type Context struct {
e Engine
}
// Engine returns db engine
func (ctx *Context) Engine() Engine {
return ctx.e
}
// NewSession returns a new session
func (ctx *Context) NewSession() *xorm.Session {
e, ok := ctx.e.(*xorm.Engine)
if ok {
return e.NewSession()
}
return nil
}
// DefaultContext represents a Context with default Engine
func DefaultContext() *Context {
return &Context{x}
}
// Committer represents an interface to Commit or Close the Context
type Committer interface {
Commit() error
Close() error
}
// TxContext represents a transaction Context
func TxContext() (*Context, Committer, error) {
sess := x.NewSession()
if err := sess.Begin(); err != nil {
sess.Close()
return nil, nil, err
}
return &Context{sess}, sess, nil
}
// WithContext represents executing database operations
func WithContext(f func(ctx *Context) error) error {
return f(&Context{x})
}
// WithTx represents executing database operations on a transaction
func WithTx(f func(ctx *Context) error) error {
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
}
if err := f(&Context{sess}); err != nil {
return err
}
return sess.Commit()
}
// Iterate iterates the databases and doing something
func Iterate(ctx *Context, tableBean interface{}, cond builder.Cond, fun func(idx int, bean interface{}) error) error {
return ctx.e.Where(cond).
BufferSize(setting.Database.IterateBufferSize).
Iterate(tableBean, fun)
}
// Insert inserts records into database
func Insert(ctx *Context, beans ...interface{}) error {
_, err := ctx.e.Insert(beans...)
return err
}

41
models/db/convert.go Normal file
View File

@@ -0,0 +1,41 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"fmt"
"code.gitea.io/gitea/modules/setting"
"xorm.io/xorm/schemas"
)
// ConvertUtf8ToUtf8mb4 converts database and tables from utf8 to utf8mb4 if it's mysql and set ROW_FORMAT=dynamic
func ConvertUtf8ToUtf8mb4() error {
if x.Dialect().URI().DBType != schemas.MYSQL {
return nil
}
_, err := x.Exec(fmt.Sprintf("ALTER DATABASE `%s` CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci", setting.Database.Name))
if err != nil {
return err
}
tables, err := x.DBMetas()
if err != nil {
return err
}
for _, table := range tables {
if _, err := x.Exec(fmt.Sprintf("ALTER TABLE `%s` ROW_FORMAT=dynamic;", table.Name)); err != nil {
return err
}
if _, err := x.Exec(fmt.Sprintf("ALTER TABLE `%s` CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci;", table.Name)); err != nil {
return err
}
}
return nil
}

297
models/db/engine.go Executable file
View File

@@ -0,0 +1,297 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2018 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"reflect"
"strings"
"code.gitea.io/gitea/modules/setting"
// Needed for the MySQL driver
_ "github.com/go-sql-driver/mysql"
"xorm.io/xorm"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
// Needed for the Postgresql driver
_ "github.com/lib/pq"
// Needed for the MSSQL driver
_ "github.com/denisenkom/go-mssqldb"
)
var (
x *xorm.Engine
tables []interface{}
initFuncs []func() error
// HasEngine specifies if we have a xorm.Engine
HasEngine bool
)
// Engine represents a xorm engine or session.
type Engine interface {
Table(tableNameOrBean interface{}) *xorm.Session
Count(...interface{}) (int64, error)
Decr(column string, arg ...interface{}) *xorm.Session
Delete(...interface{}) (int64, error)
Exec(...interface{}) (sql.Result, error)
Find(interface{}, ...interface{}) error
Get(interface{}) (bool, error)
ID(interface{}) *xorm.Session
In(string, ...interface{}) *xorm.Session
Incr(column string, arg ...interface{}) *xorm.Session
Insert(...interface{}) (int64, error)
InsertOne(interface{}) (int64, error)
Iterate(interface{}, xorm.IterFunc) error
Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *xorm.Session
SQL(interface{}, ...interface{}) *xorm.Session
Where(interface{}, ...interface{}) *xorm.Session
Asc(colNames ...string) *xorm.Session
Desc(colNames ...string) *xorm.Session
Limit(limit int, start ...int) *xorm.Session
SumInt(bean interface{}, columnName string) (res int64, err error)
Sync2(...interface{}) error
Select(string) *xorm.Session
NotIn(string, ...interface{}) *xorm.Session
OrderBy(string) *xorm.Session
Exist(...interface{}) (bool, error)
Distinct(...string) *xorm.Session
Query(...interface{}) ([]map[string][]byte, error)
Cols(...string) *xorm.Session
}
// TableInfo returns table's information via an object
func TableInfo(v interface{}) (*schemas.Table, error) {
return x.TableInfo(v)
}
// DumpTables dump tables information
func DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
return x.DumpTables(tables, w, tp...)
}
// RegisterModel registers model, if initfunc provided, it will be invoked after data model sync
func RegisterModel(bean interface{}, initFunc ...func() error) {
tables = append(tables, bean)
if len(initFuncs) > 0 && initFunc[0] != nil {
initFuncs = append(initFuncs, initFunc[0])
}
}
func init() {
gonicNames := []string{"SSL", "UID"}
for _, name := range gonicNames {
names.LintGonicMapper[name] = true
}
}
// GetNewEngine returns a new xorm engine from the configuration
func GetNewEngine() (*xorm.Engine, error) {
connStr, err := setting.DBConnStr()
if err != nil {
return nil, err
}
var engine *xorm.Engine
if setting.Database.UsePostgreSQL && len(setting.Database.Schema) > 0 {
// OK whilst we sort out our schema issues - create a schema aware postgres
registerPostgresSchemaDriver()
engine, err = xorm.NewEngine("postgresschema", connStr)
} else {
engine, err = xorm.NewEngine(setting.Database.Type, connStr)
}
if err != nil {
return nil, err
}
if setting.Database.Type == "mysql" {
engine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"})
} else if setting.Database.Type == "mssql" {
engine.Dialect().SetParams(map[string]string{"DEFAULT_VARCHAR": "nvarchar"})
}
engine.SetSchema(setting.Database.Schema)
return engine, nil
}
func syncTables() error {
return x.StoreEngine("InnoDB").Sync2(tables...)
}
// NewTestEngine sets a new test xorm.Engine
func NewTestEngine() (err error) {
x, err = GetNewEngine()
if err != nil {
return fmt.Errorf("Connect to database: %v", err)
}
x.SetMapper(names.GonicMapper{})
x.SetLogger(NewXORMLogger(!setting.IsProd()))
x.ShowSQL(!setting.IsProd())
return syncTables()
}
// SetEngine sets the xorm.Engine
func SetEngine() (err error) {
x, err = GetNewEngine()
if err != nil {
return fmt.Errorf("Failed to connect to database: %v", err)
}
x.SetMapper(names.GonicMapper{})
// WARNING: for serv command, MUST remove the output to os.stdout,
// so use log file to instead print to stdout.
x.SetLogger(NewXORMLogger(setting.Database.LogSQL))
x.ShowSQL(setting.Database.LogSQL)
x.SetMaxOpenConns(setting.Database.MaxOpenConns)
x.SetMaxIdleConns(setting.Database.MaxIdleConns)
x.SetConnMaxLifetime(setting.Database.ConnMaxLifetime)
return nil
}
// NewEngine initializes a new xorm.Engine
// This function must never call .Sync2() if the provided migration function fails.
// When called from the "doctor" command, the migration function is a version check
// that prevents the doctor from fixing anything in the database if the migration level
// is different from the expected value.
func NewEngine(ctx context.Context, migrateFunc func(*xorm.Engine) error) (err error) {
if err = SetEngine(); err != nil {
return err
}
x.SetDefaultContext(ctx)
if err = x.Ping(); err != nil {
return err
}
if err = migrateFunc(x); err != nil {
return fmt.Errorf("migrate: %v", err)
}
if err = syncTables(); err != nil {
return fmt.Errorf("sync database struct error: %v", err)
}
for _, initFunc := range initFuncs {
if err := initFunc(); err != nil {
return fmt.Errorf("initFunc failed: %v", err)
}
}
return nil
}
// NamesToBean return a list of beans or an error
func NamesToBean(names ...string) ([]interface{}, error) {
beans := []interface{}{}
if len(names) == 0 {
beans = append(beans, tables...)
return beans, nil
}
// Need to map provided names to beans...
beanMap := make(map[string]interface{})
for _, bean := range tables {
beanMap[strings.ToLower(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())] = bean
beanMap[strings.ToLower(x.TableName(bean))] = bean
beanMap[strings.ToLower(x.TableName(bean, true))] = bean
}
gotBean := make(map[interface{}]bool)
for _, name := range names {
bean, ok := beanMap[strings.ToLower(strings.TrimSpace(name))]
if !ok {
return nil, fmt.Errorf("No table found that matches: %s", name)
}
if !gotBean[bean] {
beans = append(beans, bean)
gotBean[bean] = true
}
}
return beans, nil
}
// Ping tests if database is alive
func Ping() error {
if x != nil {
return x.Ping()
}
return errors.New("database not configured")
}
// DumpDatabase dumps all data from database according the special database SQL syntax to file system.
func DumpDatabase(filePath, dbType string) error {
var tbs []*schemas.Table
for _, t := range tables {
t, err := x.TableInfo(t)
if err != nil {
return err
}
tbs = append(tbs, t)
}
type Version struct {
ID int64 `xorm:"pk autoincr"`
Version int64
}
t, err := x.TableInfo(&Version{})
if err != nil {
return err
}
tbs = append(tbs, t)
if len(dbType) > 0 {
return x.DumpTablesToFile(tbs, filePath, schemas.DBType(dbType))
}
return x.DumpTablesToFile(tbs, filePath)
}
// MaxBatchInsertSize returns the table's max batch insert size
func MaxBatchInsertSize(bean interface{}) int {
t, err := x.TableInfo(bean)
if err != nil {
return 50
}
return 999 / len(t.ColumnsSeq())
}
// Count returns records number according struct's fields as database query conditions
func Count(bean interface{}) (int64, error) {
return x.Count(bean)
}
// IsTableNotEmpty returns true if table has at least one record
func IsTableNotEmpty(tableName string) (bool, error) {
return x.Table(tableName).Exist()
}
// DeleteAllRecords will delete all the records of this table
func DeleteAllRecords(tableName string) error {
_, err := x.Exec(fmt.Sprintf("DELETE FROM %s", tableName))
return err
}
// GetMaxID will return max id of the table
func GetMaxID(beanOrTableName interface{}) (maxID int64, err error) {
_, err = x.Select("MAX(id)").Table(beanOrTableName).Get(&maxID)
return
}
// FindByMaxID filled results as the condition from database
func FindByMaxID(maxID int64, limit int, results interface{}) error {
return x.Where("id <= ?", maxID).
OrderBy("id DESC").
Limit(limit).
Find(results)
}

110
models/db/index.go Normal file
View File

@@ -0,0 +1,110 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"errors"
"fmt"
"code.gitea.io/gitea/modules/setting"
)
// ResourceIndex represents a resource index which could be used as issue/release and others
// We can create different tables i.e. issue_index, release_index and etc.
type ResourceIndex struct {
GroupID int64 `xorm:"pk"`
MaxIndex int64 `xorm:"index"`
}
// UpsertResourceIndex the function will not return until it acquires the lock or receives an error.
func UpsertResourceIndex(e Engine, tableName string, groupID int64) (err error) {
// An atomic UPSERT operation (INSERT/UPDATE) is the only operation
// that ensures that the key is actually locked.
switch {
case setting.Database.UseSQLite3 || setting.Database.UsePostgreSQL:
_, err = e.Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) "+
"VALUES (?,1) ON CONFLICT (group_id) DO UPDATE SET max_index = %s.max_index+1",
tableName, tableName), groupID)
case setting.Database.UseMySQL:
_, err = e.Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) "+
"VALUES (?,1) ON DUPLICATE KEY UPDATE max_index = max_index+1", tableName),
groupID)
case setting.Database.UseMSSQL:
// https://weblogs.sqlteam.com/dang/2009/01/31/upsert-race-condition-with-merge/
_, err = e.Exec(fmt.Sprintf("MERGE %s WITH (HOLDLOCK) as target "+
"USING (SELECT ? AS group_id) AS src "+
"ON src.group_id = target.group_id "+
"WHEN MATCHED THEN UPDATE SET target.max_index = target.max_index+1 "+
"WHEN NOT MATCHED THEN INSERT (group_id, max_index) "+
"VALUES (src.group_id, 1);", tableName),
groupID)
default:
return fmt.Errorf("database type not supported")
}
return
}
var (
// ErrResouceOutdated represents an error when request resource outdated
ErrResouceOutdated = errors.New("resource outdated")
// ErrGetResourceIndexFailed represents an error when resource index retries 3 times
ErrGetResourceIndexFailed = errors.New("get resource index failed")
)
const (
maxDupIndexAttempts = 3
)
// GetNextResourceIndex retried 3 times to generate a resource index
func GetNextResourceIndex(tableName string, groupID int64) (int64, error) {
for i := 0; i < maxDupIndexAttempts; i++ {
idx, err := getNextResourceIndex(tableName, groupID)
if err == ErrResouceOutdated {
continue
}
if err != nil {
return 0, err
}
return idx, nil
}
return 0, ErrGetResourceIndexFailed
}
// DeleteResouceIndex delete resource index
func DeleteResouceIndex(e Engine, tableName string, groupID int64) error {
_, err := e.Exec(fmt.Sprintf("DELETE FROM %s WHERE group_id=?", tableName), groupID)
return err
}
// getNextResourceIndex return the next index
func getNextResourceIndex(tableName string, groupID int64) (int64, error) {
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return 0, err
}
var preIdx int64
_, err := sess.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id = ?", tableName), groupID).Get(&preIdx)
if err != nil {
return 0, err
}
if err := UpsertResourceIndex(sess, tableName, groupID); err != nil {
return 0, err
}
var curIdx int64
has, err := sess.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id = ? AND max_index=?", tableName), groupID, preIdx+1).Get(&curIdx)
if err != nil {
return 0, err
}
if !has {
return 0, ErrResouceOutdated
}
if err := sess.Commit(); err != nil {
return 0, err
}
return curIdx, nil
}

107
models/db/log.go Normal file
View File

@@ -0,0 +1,107 @@
// Copyright 2017 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"fmt"
"code.gitea.io/gitea/modules/log"
xormlog "xorm.io/xorm/log"
)
// XORMLogBridge a logger bridge from Logger to xorm
type XORMLogBridge struct {
showSQL bool
logger log.Logger
}
// NewXORMLogger inits a log bridge for xorm
func NewXORMLogger(showSQL bool) xormlog.Logger {
return &XORMLogBridge{
showSQL: showSQL,
logger: log.GetLogger("xorm"),
}
}
const stackLevel = 8
// Log a message with defined skip and at logging level
func (l *XORMLogBridge) Log(skip int, level log.Level, format string, v ...interface{}) error {
return l.logger.Log(skip+1, level, format, v...)
}
// Debug show debug log
func (l *XORMLogBridge) Debug(v ...interface{}) {
_ = l.Log(stackLevel, log.DEBUG, fmt.Sprint(v...))
}
// Debugf show debug log
func (l *XORMLogBridge) Debugf(format string, v ...interface{}) {
_ = l.Log(stackLevel, log.DEBUG, format, v...)
}
// Error show error log
func (l *XORMLogBridge) Error(v ...interface{}) {
_ = l.Log(stackLevel, log.ERROR, fmt.Sprint(v...))
}
// Errorf show error log
func (l *XORMLogBridge) Errorf(format string, v ...interface{}) {
_ = l.Log(stackLevel, log.ERROR, format, v...)
}
// Info show information level log
func (l *XORMLogBridge) Info(v ...interface{}) {
_ = l.Log(stackLevel, log.INFO, fmt.Sprint(v...))
}
// Infof show information level log
func (l *XORMLogBridge) Infof(format string, v ...interface{}) {
_ = l.Log(stackLevel, log.INFO, format, v...)
}
// Warn show warning log
func (l *XORMLogBridge) Warn(v ...interface{}) {
_ = l.Log(stackLevel, log.WARN, fmt.Sprint(v...))
}
// Warnf show warnning log
func (l *XORMLogBridge) Warnf(format string, v ...interface{}) {
_ = l.Log(stackLevel, log.WARN, format, v...)
}
// Level get logger level
func (l *XORMLogBridge) Level() xormlog.LogLevel {
switch l.logger.GetLevel() {
case log.TRACE, log.DEBUG:
return xormlog.LOG_DEBUG
case log.INFO:
return xormlog.LOG_INFO
case log.WARN:
return xormlog.LOG_WARNING
case log.ERROR, log.CRITICAL:
return xormlog.LOG_ERR
}
return xormlog.LOG_OFF
}
// SetLevel set the logger level
func (l *XORMLogBridge) SetLevel(lvl xormlog.LogLevel) {
}
// ShowSQL set if record SQL
func (l *XORMLogBridge) ShowSQL(show ...bool) {
if len(show) > 0 {
l.showSQL = show[0]
} else {
l.showSQL = true
}
}
// IsShowSQL if record SQL
func (l *XORMLogBridge) IsShowSQL() bool {
return l.showSQL
}

14
models/db/main_test.go Normal file
View File

@@ -0,0 +1,14 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"path/filepath"
"testing"
)
func TestMain(m *testing.M) {
MainTest(m, filepath.Join("..", ".."))
}

70
models/db/sequence.go Normal file
View File

@@ -0,0 +1,70 @@
// Copyright 2018 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"fmt"
"regexp"
"code.gitea.io/gitea/modules/setting"
)
// CountBadSequences looks for broken sequences from recreate-table mistakes
func CountBadSequences() (int64, error) {
if !setting.Database.UsePostgreSQL {
return 0, nil
}
sess := x.NewSession()
defer sess.Close()
var sequences []string
schema := x.Dialect().URI().Schema
sess.Engine().SetSchema("")
if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil {
return 0, err
}
sess.Engine().SetSchema(schema)
return int64(len(sequences)), nil
}
// FixBadSequences fixes for broken sequences from recreate-table mistakes
func FixBadSequences() error {
if !setting.Database.UsePostgreSQL {
return nil
}
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
}
var sequences []string
schema := sess.Engine().Dialect().URI().Schema
sess.Engine().SetSchema("")
if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil {
return err
}
sess.Engine().SetSchema(schema)
sequenceRegexp := regexp.MustCompile(`tmp_recreate__(\w+)_id_seq.*`)
for _, sequence := range sequences {
tableName := sequenceRegexp.FindStringSubmatch(sequence)[1]
newSequenceName := tableName + "_id_seq"
if _, err := sess.Exec(fmt.Sprintf("ALTER SEQUENCE `%s` RENAME TO `%s`", sequence, newSequenceName)); err != nil {
return err
}
if _, err := sess.Exec(fmt.Sprintf("SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM `%s`), 1), false)", newSequenceName, tableName)); err != nil {
return err
}
}
return sess.Commit()
}

View File

@@ -0,0 +1,75 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"database/sql"
"database/sql/driver"
"sync"
"code.gitea.io/gitea/modules/setting"
"github.com/lib/pq"
"xorm.io/xorm/dialects"
)
var registerOnce sync.Once
func registerPostgresSchemaDriver() {
registerOnce.Do(func() {
sql.Register("postgresschema", &postgresSchemaDriver{})
dialects.RegisterDriver("postgresschema", dialects.QueryDriver("postgres"))
})
}
type postgresSchemaDriver struct {
pq.Driver
}
// Open opens a new connection to the database. name is a connection string.
// This function opens the postgres connection in the default manner but immediately
// runs set_config to set the search_path appropriately
func (d *postgresSchemaDriver) Open(name string) (driver.Conn, error) {
conn, err := d.Driver.Open(name)
if err != nil {
return conn, err
}
schemaValue, _ := driver.String.ConvertValue(setting.Database.Schema)
// golangci lint is incorrect here - there is no benefit to using driver.ExecerContext here
// and in any case pq does not implement it
if execer, ok := conn.(driver.Execer); ok { //nolint
_, err := execer.Exec(`SELECT set_config(
'search_path',
$1 || ',' || current_setting('search_path'),
false)`, []driver.Value{schemaValue}) //nolint
if err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
stmt, err := conn.Prepare(`SELECT set_config(
'search_path',
$1 || ',' || current_setting('search_path'),
false)`)
if err != nil {
_ = conn.Close()
return nil, err
}
defer stmt.Close()
// driver.String.ConvertValue will never return err for string
// golangci lint is incorrect here - there is no benefit to using stmt.ExecWithContext here
_, err = stmt.Exec([]driver.Value{schemaValue}) //nolint
if err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}

18
models/db/store.go Normal file
View File

@@ -0,0 +1,18 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"github.com/lafriks/xormstore"
)
// CreateStore creates a xormstore for the provided table and key
func CreateStore(table, key string) (*xormstore.Store, error) {
store, err := xormstore.NewOptions(x, xormstore.Options{
TableName: table,
}, []byte(key))
return store, err
}

112
models/db/test_fixtures.go Normal file
View File

@@ -0,0 +1,112 @@
// Copyright 2017 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"fmt"
"os"
"time"
"github.com/go-testfixtures/testfixtures/v3"
"xorm.io/xorm"
"xorm.io/xorm/schemas"
)
var fixtures *testfixtures.Loader
// InitFixtures initialize test fixtures for a test database
func InitFixtures(dir string, engine ...*xorm.Engine) (err error) {
e := x
if len(engine) == 1 {
e = engine[0]
}
testfiles := testfixtures.Directory(dir)
dialect := "unknown"
switch e.Dialect().URI().DBType {
case schemas.POSTGRES:
dialect = "postgres"
case schemas.MYSQL:
dialect = "mysql"
case schemas.MSSQL:
dialect = "mssql"
case schemas.SQLITE:
dialect = "sqlite3"
default:
fmt.Println("Unsupported RDBMS for integration tests")
os.Exit(1)
}
loaderOptions := []func(loader *testfixtures.Loader) error{
testfixtures.Database(e.DB().DB),
testfixtures.Dialect(dialect),
testfixtures.DangerousSkipTestDatabaseCheck(),
testfiles,
}
if e.Dialect().URI().DBType == schemas.POSTGRES {
loaderOptions = append(loaderOptions, testfixtures.SkipResetSequences())
}
fixtures, err = testfixtures.New(loaderOptions...)
if err != nil {
return err
}
return err
}
// LoadFixtures load fixtures for a test database
func LoadFixtures(engine ...*xorm.Engine) error {
e := x
if len(engine) == 1 {
e = engine[0]
}
var err error
// Database transaction conflicts could occur and result in ROLLBACK
// As a simple workaround, we just retry 20 times.
for i := 0; i < 20; i++ {
err = fixtures.Load()
if err == nil {
break
}
time.Sleep(200 * time.Millisecond)
}
if err != nil {
fmt.Printf("LoadFixtures failed after retries: %v\n", err)
}
// Now if we're running postgres we need to tell it to update the sequences
if e.Dialect().URI().DBType == schemas.POSTGRES {
results, err := e.QueryString(`SELECT 'SELECT SETVAL(' ||
quote_literal(quote_ident(PGT.schemaname) || '.' || quote_ident(S.relname)) ||
', COALESCE(MAX(' ||quote_ident(C.attname)|| '), 1) ) FROM ' ||
quote_ident(PGT.schemaname)|| '.'||quote_ident(T.relname)|| ';'
FROM pg_class AS S,
pg_depend AS D,
pg_class AS T,
pg_attribute AS C,
pg_tables AS PGT
WHERE S.relkind = 'S'
AND S.oid = D.objid
AND D.refobjid = T.oid
AND D.refobjid = C.attrelid
AND D.refobjsubid = C.attnum
AND T.relname = PGT.tablename
ORDER BY S.relname;`)
if err != nil {
fmt.Printf("Failed to generate sequence update: %v\n", err)
return err
}
for _, r := range results {
for _, value := range r {
_, err = e.Exec(value)
if err != nil {
fmt.Printf("Failed to update sequence: %s Error: %v\n", value, err)
return err
}
}
}
}
return err
}

230
models/db/unit_tests.go Normal file
View File

@@ -0,0 +1,230 @@
// Copyright 2016 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"fmt"
"io/ioutil"
"math"
"net/url"
"os"
"path/filepath"
"testing"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/storage"
"code.gitea.io/gitea/modules/util"
"github.com/stretchr/testify/assert"
"xorm.io/xorm"
"xorm.io/xorm/names"
)
// NonexistentID an ID that will never exist
const NonexistentID = int64(math.MaxInt64)
// giteaRoot a path to the gitea root
var (
giteaRoot string
fixturesDir string
)
// FixturesDir returns the fixture directory
func FixturesDir() string {
return fixturesDir
}
func fatalTestError(fmtStr string, args ...interface{}) {
fmt.Fprintf(os.Stderr, fmtStr, args...)
os.Exit(1)
}
// MainTest a reusable TestMain(..) function for unit tests that need to use a
// test database. Creates the test database, and sets necessary settings.
func MainTest(m *testing.M, pathToGiteaRoot string) {
var err error
giteaRoot = pathToGiteaRoot
fixturesDir = filepath.Join(pathToGiteaRoot, "models", "fixtures")
if err = CreateTestEngine(fixturesDir); err != nil {
fatalTestError("Error creating test engine: %v\n", err)
}
setting.AppURL = "https://try.gitea.io/"
setting.RunUser = "runuser"
setting.SSH.Port = 3000
setting.SSH.Domain = "try.gitea.io"
setting.Database.UseSQLite3 = true
setting.RepoRootPath, err = ioutil.TempDir(os.TempDir(), "repos")
if err != nil {
fatalTestError("TempDir: %v\n", err)
}
setting.AppDataPath, err = ioutil.TempDir(os.TempDir(), "appdata")
if err != nil {
fatalTestError("TempDir: %v\n", err)
}
setting.AppWorkPath = pathToGiteaRoot
setting.StaticRootPath = pathToGiteaRoot
setting.GravatarSourceURL, err = url.Parse("https://secure.gravatar.com/avatar/")
if err != nil {
fatalTestError("url.Parse: %v\n", err)
}
setting.Attachment.Storage.Path = filepath.Join(setting.AppDataPath, "attachments")
setting.LFS.Storage.Path = filepath.Join(setting.AppDataPath, "lfs")
setting.Avatar.Storage.Path = filepath.Join(setting.AppDataPath, "avatars")
setting.RepoAvatar.Storage.Path = filepath.Join(setting.AppDataPath, "repo-avatars")
setting.RepoArchive.Storage.Path = filepath.Join(setting.AppDataPath, "repo-archive")
if err = storage.Init(); err != nil {
fatalTestError("storage.Init: %v\n", err)
}
if err = util.RemoveAll(setting.RepoRootPath); err != nil {
fatalTestError("util.RemoveAll: %v\n", err)
}
if err = util.CopyDir(filepath.Join(pathToGiteaRoot, "integrations", "gitea-repositories-meta"), setting.RepoRootPath); err != nil {
fatalTestError("util.CopyDir: %v\n", err)
}
exitStatus := m.Run()
if err = util.RemoveAll(setting.RepoRootPath); err != nil {
fatalTestError("util.RemoveAll: %v\n", err)
}
if err = util.RemoveAll(setting.AppDataPath); err != nil {
fatalTestError("util.RemoveAll: %v\n", err)
}
os.Exit(exitStatus)
}
// CreateTestEngine creates a memory database and loads the fixture data from fixturesDir
func CreateTestEngine(fixturesDir string) error {
var err error
x, err = xorm.NewEngine("sqlite3", "file::memory:?cache=shared&_txlock=immediate")
if err != nil {
return err
}
x.SetMapper(names.GonicMapper{})
if err = syncTables(); err != nil {
return err
}
switch os.Getenv("GITEA_UNIT_TESTS_VERBOSE") {
case "true", "1":
x.ShowSQL(true)
}
return InitFixtures(fixturesDir)
}
// PrepareTestDatabase load test fixtures into test database
func PrepareTestDatabase() error {
return LoadFixtures()
}
// PrepareTestEnv prepares the environment for unit tests. Can only be called
// by tests that use the above MainTest(..) function.
func PrepareTestEnv(t testing.TB) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, util.RemoveAll(setting.RepoRootPath))
metaPath := filepath.Join(giteaRoot, "integrations", "gitea-repositories-meta")
assert.NoError(t, util.CopyDir(metaPath, setting.RepoRootPath))
base.SetupGiteaRoot() // Makes sure GITEA_ROOT is set
}
type testCond struct {
query interface{}
args []interface{}
}
// Cond create a condition with arguments for a test
func Cond(query interface{}, args ...interface{}) interface{} {
return &testCond{query: query, args: args}
}
func whereConditions(sess *xorm.Session, conditions []interface{}) {
for _, condition := range conditions {
switch cond := condition.(type) {
case *testCond:
sess.Where(cond.query, cond.args...)
default:
sess.Where(cond)
}
}
}
// LoadBeanIfExists loads beans from fixture database if exist
func LoadBeanIfExists(bean interface{}, conditions ...interface{}) (bool, error) {
return loadBeanIfExists(bean, conditions...)
}
func loadBeanIfExists(bean interface{}, conditions ...interface{}) (bool, error) {
sess := x.NewSession()
defer sess.Close()
whereConditions(sess, conditions)
return sess.Get(bean)
}
// BeanExists for testing, check if a bean exists
func BeanExists(t testing.TB, bean interface{}, conditions ...interface{}) bool {
exists, err := loadBeanIfExists(bean, conditions...)
assert.NoError(t, err)
return exists
}
// AssertExistsAndLoadBean assert that a bean exists and load it from the test
// database
func AssertExistsAndLoadBean(t testing.TB, bean interface{}, conditions ...interface{}) interface{} {
exists, err := loadBeanIfExists(bean, conditions...)
assert.NoError(t, err)
assert.True(t, exists,
"Expected to find %+v (of type %T, with conditions %+v), but did not",
bean, bean, conditions)
return bean
}
// GetCount get the count of a bean
func GetCount(t testing.TB, bean interface{}, conditions ...interface{}) int {
sess := x.NewSession()
defer sess.Close()
whereConditions(sess, conditions)
count, err := sess.Count(bean)
assert.NoError(t, err)
return int(count)
}
// AssertNotExistsBean assert that a bean does not exist in the test database
func AssertNotExistsBean(t testing.TB, bean interface{}, conditions ...interface{}) {
exists, err := loadBeanIfExists(bean, conditions...)
assert.NoError(t, err)
assert.False(t, exists)
}
// AssertExistsIf asserts that a bean exists or does not exist, depending on
// what is expected.
func AssertExistsIf(t *testing.T, expected bool, bean interface{}, conditions ...interface{}) {
exists, err := loadBeanIfExists(bean, conditions...)
assert.NoError(t, err)
assert.Equal(t, expected, exists)
}
// AssertSuccessfulInsert assert that beans is successfully inserted
func AssertSuccessfulInsert(t testing.TB, beans ...interface{}) {
_, err := x.Insert(beans...)
assert.NoError(t, err)
}
// AssertCount assert the count of a bean
func AssertCount(t testing.TB, bean, expected interface{}) {
assert.EqualValues(t, expected, GetCount(t, bean))
}
// AssertInt64InRange assert value is in range [low, high]
func AssertInt64InRange(t testing.TB, low, high, value int64) {
assert.True(t, value >= low && value <= high,
"Expected value in range [%d, %d], found %d", low, high, value)
}