1
1
mirror of https://github.com/go-gitea/gitea synced 2025-07-22 18:28:37 +00:00

Upgrade xorm to v1.1.0 (#15869)

This commit is contained in:
Lunny Xiao
2021-05-15 03:17:06 +08:00
committed by GitHub
parent e2f39c2b64
commit f6be429781
55 changed files with 1309 additions and 438 deletions

View File

@@ -79,32 +79,34 @@ type Base struct {
quoter schemas.Quoter
}
func (b *Base) Quoter() schemas.Quoter {
return b.quoter
// Quoter returns the current database Quoter
func (db *Base) Quoter() schemas.Quoter {
return db.quoter
}
func (b *Base) Init(dialect Dialect, uri *URI) error {
b.dialect, b.uri = dialect, uri
// Init initialize the dialect
func (db *Base) Init(dialect Dialect, uri *URI) error {
db.dialect, db.uri = dialect, uri
return nil
}
func (b *Base) URI() *URI {
return b.uri
// URI returns the uri of database
func (db *Base) URI() *URI {
return db.uri
}
func (b *Base) DBType() schemas.DBType {
return b.uri.DBType
}
func (b *Base) FormatBytes(bs []byte) string {
// FormatBytes formats bytes
func (db *Base) FormatBytes(bs []byte) string {
return fmt.Sprintf("0x%x", bs)
}
// DropTableSQL returns drop table SQL
func (db *Base) DropTableSQL(tableName string) (string, bool) {
quote := db.dialect.Quoter().Quote
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true
}
// HasRecords returns true if the SQL has records returned
func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) {
rows, err := queryer.QueryContext(ctx, query, args...)
if err != nil {
@@ -118,6 +120,7 @@ func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query stri
return false, nil
}
// IsColumnExist returns true if the column of the table exist
func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
quote := db.dialect.Quoter().Quote
query := fmt.Sprintf(
@@ -132,11 +135,13 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa
return db.HasRecords(queryer, ctx, query, db.uri.DBName, tableName, colName)
}
// AddColumnSQL returns a SQL to add a column
func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, true)
return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName), s)
}
// CreateIndexSQL returns a SQL to create index
func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
quoter := db.dialect.Quoter()
var unique string
@@ -150,6 +155,7 @@ func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
quoter.Join(index.Cols, ","))
}
// DropIndexSQL returns a SQL to drop index
func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
quote := db.dialect.Quoter().Quote
var name string
@@ -161,16 +167,19 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName))
}
// ModifyColumnSQL returns a SQL to modify SQL
func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false)
return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", tableName, s)
}
func (b *Base) ForUpdateSQL(query string) string {
// ForUpdateSQL returns for updateSQL
func (db *Base) ForUpdateSQL(query string) string {
return query + " FOR UPDATE"
}
func (b *Base) SetParams(params map[string]string) {
// SetParams set params
func (db *Base) SetParams(params map[string]string) {
}
var (
@@ -206,6 +215,7 @@ func regDrvsNDialects() bool {
"postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }},
"pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }},
"sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
"sqlite": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
"oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }},
"goracle": {"oracle", func() Driver { return &goracleDriver{} }, func() Dialect { return &oracle{} }},
}

View File

@@ -8,6 +8,7 @@ import (
"fmt"
)
// Driver represents a database driver
type Driver interface {
Parse(string, string) (*URI, error)
}
@@ -16,6 +17,7 @@ var (
drivers = map[string]Driver{}
)
// RegisterDriver register a driver
func RegisterDriver(driverName string, driver Driver) {
if driver == nil {
panic("core: Register driver is nil")
@@ -26,10 +28,12 @@ func RegisterDriver(driverName string, driver Driver) {
drivers[driverName] = driver
}
// QueryDriver query a driver with name
func QueryDriver(driverName string) Driver {
return drivers[driverName]
}
// RegisteredDriverSize returned all drivers's length
func RegisteredDriverSize() int {
return len(drivers)
}
@@ -38,7 +42,7 @@ func RegisteredDriverSize() int {
func OpenDialect(driverName, connstr string) (Dialect, error) {
driver := QueryDriver(driverName)
if driver == nil {
return nil, fmt.Errorf("Unsupported driver name: %v", driverName)
return nil, fmt.Errorf("unsupported driver name: %v", driverName)
}
uri, err := driver.Parse(driverName, connstr)
@@ -48,7 +52,7 @@ func OpenDialect(driverName, connstr string) (Dialect, error) {
dialect := QueryDialect(uri.DBType)
if dialect == nil {
return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType)
return nil, fmt.Errorf("unsupported dialect type: %v", uri.DBType)
}
dialect.Init(uri)

View File

@@ -38,6 +38,7 @@ func convertQuestionMark(sql, prefix string, start int) string {
return buf.String()
}
// Do implements Filter
func (s *SeqFilter) Do(sql string) string {
return convertQuestionMark(sql, s.Prefix, s.Start)
}

View File

@@ -284,7 +284,7 @@ func (db *mssql) SQLType(c *schemas.Column) string {
case schemas.TimeStampz:
res = "DATETIMEOFFSET"
c.Length = 7
case schemas.MediumInt:
case schemas.MediumInt, schemas.UnsignedInt:
res = schemas.Int
case schemas.Text, schemas.MediumText, schemas.TinyText, schemas.LongText, schemas.Json:
res = db.defaultVarchar + "(MAX)"
@@ -296,7 +296,7 @@ func (db *mssql) SQLType(c *schemas.Column) string {
case schemas.TinyInt:
res = schemas.TinyInt
c.Length = 0
case schemas.BigInt:
case schemas.BigInt, schemas.UnsignedBigInt:
res = schemas.BigInt
c.Length = 0
case schemas.NVarchar:

View File

@@ -254,6 +254,10 @@ func (db *mysql) SQLType(c *schemas.Column) string {
c.Length = 40
case schemas.Json:
res = schemas.Text
case schemas.UnsignedInt:
res = schemas.Int
case schemas.UnsignedBigInt:
res = schemas.BigInt
default:
res = t
}
@@ -271,6 +275,11 @@ func (db *mysql) SQLType(c *schemas.Column) string {
} else if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")"
}
if c.SQLType.Name == schemas.UnsignedBigInt || c.SQLType.Name == schemas.UnsignedInt {
res += " UNSIGNED"
}
return res
}
@@ -331,16 +340,16 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
col := new(schemas.Column)
col.Indexes = make(map[string]int)
var columnName, isNullable, colType, colKey, extra, comment string
var alreadyQuoted bool
var columnName, nullableStr, colType, colKey, extra, comment string
var alreadyQuoted, isUnsigned bool
var colDefault *string
err = rows.Scan(&columnName, &isNullable, &colDefault, &colType, &colKey, &extra, &comment, &alreadyQuoted)
err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &alreadyQuoted)
if err != nil {
return nil, nil, err
}
col.Name = strings.Trim(columnName, "` ")
col.Comment = comment
if "YES" == isNullable {
if nullableStr == "YES" {
col.Nullable = true
}
@@ -351,8 +360,15 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
col.DefaultIsEmpty = true
}
fields := strings.Fields(colType)
if len(fields) == 2 && fields[1] == "unsigned" {
isUnsigned = true
}
colType = fields[0]
cts := strings.Split(colType, "(")
colName := cts[0]
// Remove the /* mariadb-5.3 */ suffix from coltypes
colName = strings.TrimSuffix(colName, "/* mariadb-5.3 */")
colType = strings.ToUpper(colName)
var len1, len2 int
if len(cts) == 2 {
@@ -387,11 +403,8 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
}
}
}
if colType == "FLOAT UNSIGNED" {
colType = "FLOAT"
}
if colType == "DOUBLE UNSIGNED" {
colType = "DOUBLE"
if isUnsigned {
colType = "UNSIGNED " + colType
}
col.Length = len1
col.Length2 = len2

View File

@@ -824,6 +824,11 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) {
}
}
// FormatBytes formats bytes
func (db *postgres) FormatBytes(bs []byte) string {
return fmt.Sprintf("E'\\x%x'", bs)
}
func (db *postgres) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {
@@ -833,12 +838,12 @@ func (db *postgres) SQLType(c *schemas.Column) string {
case schemas.Bit:
res = schemas.Boolean
return res
case schemas.MediumInt, schemas.Int, schemas.Integer:
case schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedInt:
if c.IsAutoIncrement {
return schemas.Serial
}
return schemas.Integer
case schemas.BigInt:
case schemas.BigInt, schemas.UnsignedBigInt:
if c.IsAutoIncrement {
return schemas.BigSerial
}
@@ -1052,6 +1057,10 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
}
}
if colDefault != nil && *colDefault == "unique_rowid()" { // ignore the system column added by cockroach
continue
}
col.Name = strings.Trim(colName, `" `)
if colDefault != nil {

View File

@@ -193,7 +193,8 @@ func (db *sqlite3) SQLType(c *schemas.Column) string {
case schemas.Char, schemas.Varchar, schemas.NVarchar, schemas.TinyText,
schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json:
return schemas.Text
case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt:
case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt,
schemas.UnsignedBigInt, schemas.UnsignedInt:
return schemas.Integer
case schemas.Float, schemas.Double, schemas.Real:
return schemas.Real

View File

@@ -19,7 +19,11 @@ func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}
case schemas.Date:
v = t.Format("2006-01-02")
case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar.
v = t.Format("2006-01-02 15:04:05")
if dialect.URI().DBType == schemas.ORACLE {
v = t
} else {
v = t.Format("2006-01-02 15:04:05")
}
case schemas.TimeStampz:
if dialect.URI().DBType == schemas.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
@@ -34,6 +38,7 @@ func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}
return
}
// FormatColumnTime format column time
func FormatColumnTime(dialect Dialect, defaultTimeZone *time.Location, col *schemas.Column, t time.Time) (v interface{}) {
if t.IsZero() {
if col.Nullable {