1
1
mirror of https://github.com/go-gitea/gitea synced 2025-07-22 18:28:37 +00:00

Upgrade xorm to v1.2.2 (#16663)

* Upgrade xorm to v1.2.2

* Change the Engine interface to match xorm v1.2.2
This commit is contained in:
Lunny Xiao
2021-08-13 07:11:42 +08:00
committed by GitHub
parent 5fbccad906
commit 7224cfc578
134 changed files with 42889 additions and 5428 deletions

View File

@@ -42,10 +42,12 @@ func (uri *URI) SetSchema(schema string) {
type Dialect interface {
Init(*URI) error
URI() *URI
SQLType(*schemas.Column) string
FormatBytes(b []byte) string
Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error)
SQLType(*schemas.Column) string
Alias(string) string // return what a sql type's alias of
ColumnTypeKind(string) int // database column type kind
IsReserved(string) bool
Quoter() schemas.Quoter
SetQuotePolicy(quotePolicy QuotePolicy)
@@ -80,6 +82,11 @@ type Base struct {
quoter schemas.Quoter
}
// Alias returned col itself
func (db *Base) Alias(col string) string {
return col
}
// Quoter returns the current database Quoter
func (db *Base) Quoter() schemas.Quoter {
return db.quoter
@@ -96,11 +103,6 @@ func (db *Base) URI() *URI {
return db.uri
}
// FormatBytes formats bytes
func (db *Base) FormatBytes(bs []byte) string {
return fmt.Sprintf("0x%x", bs)
}
// DropTableSQL returns drop table SQL
func (db *Base) DropTableSQL(tableName string) (string, bool) {
quote := db.dialect.Quoter().Quote
@@ -118,7 +120,7 @@ func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query stri
if rows.Next() {
return true, nil
}
return false, nil
return false, rows.Err()
}
// IsColumnExist returns true if the column of the table exist

View File

@@ -5,12 +5,30 @@
package dialects
import (
"database/sql"
"fmt"
"time"
"xorm.io/xorm/core"
)
// ScanContext represents a context when Scan
type ScanContext struct {
DBLocation *time.Location
UserLocation *time.Location
}
// DriverFeatures represents driver feature
type DriverFeatures struct {
SupportReturnInsertedID bool
}
// Driver represents a database driver
type Driver interface {
Parse(string, string) (*URI, error)
Features() *DriverFeatures
GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface
Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error
}
var (
@@ -59,3 +77,9 @@ func OpenDialect(driverName, connstr string) (Dialect, error) {
return dialect, nil
}
type baseDriver struct{}
func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error {
return rows.Scan(v...)
}

View File

@@ -6,6 +6,7 @@ package dialects
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
@@ -263,6 +264,9 @@ func (db *mssql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve
var version, level, edition string
if !rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
return nil, errors.New("unknow version")
}
@@ -281,7 +285,7 @@ func (db *mssql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve
func (db *mssql) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {
case schemas.Bool:
case schemas.Bool, schemas.Boolean:
res = schemas.Bit
if strings.EqualFold(c.Default, "true") {
c.Default = "1"
@@ -299,17 +303,26 @@ func (db *mssql) SQLType(c *schemas.Column) string {
c.IsPrimaryKey = true
c.Nullable = false
res = schemas.BigInt
case schemas.Bytea, schemas.Blob, schemas.Binary, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob:
case schemas.Bytea, schemas.Binary:
res = schemas.VarBinary
if c.Length == 0 {
c.Length = 50
}
case schemas.TimeStamp:
res = schemas.DateTime
case schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob:
res = schemas.VarBinary
if c.Length == 0 {
res += "(MAX)"
}
case schemas.TimeStamp, schemas.DateTime:
if c.Length > 3 {
res = "DATETIME2"
} else {
return schemas.DateTime
}
case schemas.TimeStampz:
res = "DATETIMEOFFSET"
c.Length = 7
case schemas.MediumInt:
case schemas.MediumInt, schemas.TinyInt, schemas.SmallInt, schemas.UnsignedMediumInt, schemas.UnsignedTinyInt, schemas.UnsignedSmallInt:
res = schemas.Int
case schemas.Text, schemas.MediumText, schemas.TinyText, schemas.LongText, schemas.Json:
res = db.defaultVarchar + "(MAX)"
@@ -348,7 +361,7 @@ func (db *mssql) SQLType(c *schemas.Column) string {
res = t
}
if res == schemas.Int || res == schemas.Bit || res == schemas.DateTime {
if res == schemas.Int || res == schemas.Bit {
return res
}
@@ -363,6 +376,19 @@ func (db *mssql) SQLType(c *schemas.Column) string {
return res
}
func (db *mssql) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATE", "DATETIME", "DATETIME2", "TIME":
return schemas.TIME_TYPE
case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT":
return schemas.TEXT_TYPE
case "FLOAT", "REAL", "BIGINT", "DATETIMEOFFSET", "TINYINT", "SMALLINT", "INT":
return schemas.NUMERIC_TYPE
default:
return schemas.UNKNOW_TYPE
}
}
func (db *mssql) IsReserved(name string) bool {
_, ok := mssqlReservedWords[strings.ToUpper(name)]
return ok
@@ -476,6 +502,12 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
col.Length /= 2
col.Length2 /= 2
}
case "DATETIME2":
col.SQLType = schemas.SQLType{Name: schemas.DateTime, DefaultLength: 7, DefaultLength2: 0}
col.Length = scale
case "DATETIME":
col.SQLType = schemas.SQLType{Name: schemas.DateTime, DefaultLength: 3, DefaultLength2: 0}
col.Length = scale
case "IMAGE":
col.SQLType = schemas.SQLType{Name: schemas.VarBinary, DefaultLength: 0, DefaultLength2: 0}
case "NCHAR":
@@ -495,6 +527,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
cols[col.Name] = col
colSeq = append(colSeq, col.Name)
}
if rows.Err() != nil {
return nil, nil, rows.Err()
}
return colSeq, cols, nil
}
@@ -519,6 +554,9 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
table.Name = strings.Trim(name, "` ")
tables = append(tables, table)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return tables, nil
}
@@ -542,7 +580,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
}
defer rows.Close()
indexes := make(map[string]*schemas.Index, 0)
indexes := make(map[string]*schemas.Index)
for rows.Next() {
var indexType int
var indexName, colName, isUnique string
@@ -581,6 +619,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
}
index.AddColumn(colName)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return indexes, nil
}
@@ -624,6 +665,13 @@ func (db *mssql) Filters() []Filter {
}
type odbcDriver struct {
baseDriver
}
func (p *odbcDriver) Features() *DriverFeatures {
return &DriverFeatures{
SupportReturnInsertedID: false,
}
}
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
@@ -640,8 +688,7 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
for _, c := range kv {
vv := strings.Split(strings.TrimSpace(c), "=")
if len(vv) == 2 {
switch strings.ToLower(vv[0]) {
case "database":
if strings.ToLower(vv[0]) == "database" {
dbName = vv[1]
}
}
@@ -652,3 +699,26 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
}
return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil
}
func (p *odbcDriver) GenScanResult(colType string) (interface{}, error) {
switch colType {
case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT":
fallthrough
case "DATE", "DATETIME", "DATETIME2", "TIME":
var s sql.NullString
return &s, nil
case "FLOAT", "REAL":
var s sql.NullFloat64
return &s, nil
case "BIGINT", "DATETIMEOFFSET":
var s sql.NullInt64
return &s, nil
case "TINYINT", "SMALLINT", "INT":
var s sql.NullInt32
return &s, nil
default:
var r sql.RawBytes
return &r, nil
}
}

180
vendor/xorm.io/xorm/dialects/mysql.go generated vendored
View File

@@ -7,6 +7,7 @@ package dialects
import (
"context"
"crypto/tls"
"database/sql"
"errors"
"fmt"
"regexp"
@@ -188,6 +189,21 @@ func (db *mysql) Init(uri *URI) error {
return db.Base.Init(db, uri)
}
var (
mysqlColAliases = map[string]string{
"numeric": "decimal",
}
)
// Alias returns a alias of column
func (db *mysql) Alias(col string) string {
v, ok := mysqlColAliases[strings.ToLower(col)]
if ok {
return v
}
return col
}
func (db *mysql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) {
rows, err := queryer.QueryContext(ctx, "SELECT @@VERSION")
if err != nil {
@@ -197,7 +213,10 @@ func (db *mysql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve
var version string
if !rows.Next() {
return nil, errors.New("Unknow version")
if rows.Err() != nil {
return nil, rows.Err()
}
return nil, errors.New("unknow version")
}
if err := rows.Scan(&version); err != nil {
@@ -238,15 +257,13 @@ func (db *mysql) SetParams(params map[string]string) {
fallthrough
case "COMPRESSED":
db.rowFormat = t
break
default:
break
}
}
}
func (db *mysql) SQLType(c *schemas.Column) string {
var res string
var isUnsigned bool
switch t := c.SQLType.Name; t {
case schemas.Bool:
res = schemas.TinyInt
@@ -293,8 +310,19 @@ func (db *mysql) SQLType(c *schemas.Column) string {
res = schemas.Text
case schemas.UnsignedInt:
res = schemas.Int
isUnsigned = true
case schemas.UnsignedBigInt:
res = schemas.BigInt
isUnsigned = true
case schemas.UnsignedMediumInt:
res = schemas.MediumInt
isUnsigned = true
case schemas.UnsignedSmallInt:
res = schemas.SmallInt
isUnsigned = true
case schemas.UnsignedTinyInt:
res = schemas.TinyInt
isUnsigned = true
default:
res = t
}
@@ -313,13 +341,28 @@ func (db *mysql) SQLType(c *schemas.Column) string {
res += "(" + strconv.Itoa(c.Length) + ")"
}
if c.SQLType.Name == schemas.UnsignedBigInt || c.SQLType.Name == schemas.UnsignedInt {
if isUnsigned {
res += " UNSIGNED"
}
return res
}
func (db *mysql) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATETIME":
return schemas.TIME_TYPE
case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET":
return schemas.TEXT_TYPE
case "BIGINT", "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "FLOAT", "REAL", "DOUBLE PRECISION", "DECIMAL", "NUMERIC", "BIT":
return schemas.NUMERIC_TYPE
case "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB":
return schemas.BLOB_TYPE
default:
return schemas.UNKNOW_TYPE
}
}
func (db *mysql) IsReserved(name string) bool {
_, ok := mysqlReservedWords[strings.ToUpper(name)]
return ok
@@ -472,6 +515,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
cols[col.Name] = col
colSeq = append(colSeq, col.Name)
}
if rows.Err() != nil {
return nil, nil, rows.Err()
}
return colSeq, cols, nil
}
@@ -503,6 +549,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
table.StoreEngine = engine
tables = append(tables, table)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return tables, nil
}
@@ -533,7 +582,7 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName
}
defer rows.Close()
indexes := make(map[string]*schemas.Index, 0)
indexes := make(map[string]*schemas.Index)
for rows.Next() {
var indexType int
var indexName, colName, nonUnique string
@@ -546,7 +595,7 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName
continue
}
if "YES" == nonUnique || nonUnique == "1" {
if nonUnique == "YES" || nonUnique == "1" {
indexType = schemas.IndexType
} else {
indexType = schemas.UniqueType
@@ -570,6 +619,9 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName
}
index.AddColumn(colName)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return indexes, nil
}
@@ -630,7 +682,83 @@ func (db *mysql) Filters() []Filter {
return []Filter{}
}
type mysqlDriver struct {
baseDriver
}
func (p *mysqlDriver) Features() *DriverFeatures {
return &DriverFeatures{
SupportReturnInsertedID: true,
}
}
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
matches := dsnPattern.FindStringSubmatch(dataSourceName)
// tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames()
uri := &URI{DBType: schemas.MYSQL}
for i, match := range matches {
switch names[i] {
case "dbname":
uri.DBName = match
case "params":
if len(match) > 0 {
kvs := strings.Split(match, "&")
for _, kv := range kvs {
splits := strings.Split(kv, "=")
if len(splits) == 2 {
if splits[0] == "charset" {
uri.Charset = splits[1]
}
}
}
}
}
}
return uri, nil
}
func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) {
switch colType {
case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET":
var s sql.NullString
return &s, nil
case "BIGINT":
var s sql.NullInt64
return &s, nil
case "TINYINT", "SMALLINT", "MEDIUMINT", "INT":
var s sql.NullInt32
return &s, nil
case "FLOAT", "REAL", "DOUBLE PRECISION", "DOUBLE":
var s sql.NullFloat64
return &s, nil
case "DECIMAL", "NUMERIC":
var s sql.NullString
return &s, nil
case "DATETIME", "TIMESTAMP":
var s sql.NullTime
return &s, nil
case "BIT":
var s sql.RawBytes
return &s, nil
case "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB":
var r sql.RawBytes
return &r, nil
default:
var r sql.RawBytes
return &r, nil
}
}
type mymysqlDriver struct {
mysqlDriver
}
func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
@@ -681,41 +809,3 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
return uri, nil
}
type mysqlDriver struct {
}
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
matches := dsnPattern.FindStringSubmatch(dataSourceName)
// tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames()
uri := &URI{DBType: schemas.MYSQL}
for i, match := range matches {
switch names[i] {
case "dbname":
uri.DBName = match
case "params":
if len(match) > 0 {
kvs := strings.Split(match, "&")
for _, kv := range kvs {
splits := strings.Split(kv, "=")
if len(splits) == 2 {
switch splits[0] {
case "charset":
uri.Charset = splits[1]
}
}
}
}
}
}
return uri, nil
}

View File

@@ -6,6 +6,7 @@ package dialects
import (
"context"
"database/sql"
"errors"
"fmt"
"regexp"
@@ -524,6 +525,9 @@ func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.V
var version string
if !rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
return nil, errors.New("unknow version")
}
@@ -567,6 +571,21 @@ func (db *oracle) SQLType(c *schemas.Column) string {
return res
}
func (db *oracle) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATE":
return schemas.TIME_TYPE
case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB":
return schemas.TEXT_TYPE
case "NUMBER":
return schemas.NUMERIC_TYPE
case "BLOB":
return schemas.BLOB_TYPE
default:
return schemas.UNKNOW_TYPE
}
}
func (db *oracle) AutoIncrStr() string {
return "AUTO_INCREMENT"
}
@@ -740,6 +759,9 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
cols[col.Name] = col
colSeq = append(colSeq, col.Name)
}
if rows.Err() != nil {
return nil, nil, rows.Err()
}
return colSeq, cols, nil
}
@@ -764,6 +786,9 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem
tables = append(tables, table)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return tables, nil
}
@@ -778,7 +803,7 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam
}
defer rows.Close()
indexes := make(map[string]*schemas.Index, 0)
indexes := make(map[string]*schemas.Index)
for rows.Next() {
var indexType int
var indexName, colName, uniqueness string
@@ -813,6 +838,9 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam
}
index.AddColumn(colName)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return indexes, nil
}
@@ -823,9 +851,16 @@ func (db *oracle) Filters() []Filter {
}
type godrorDriver struct {
baseDriver
}
func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) {
func (g *godrorDriver) Features() *DriverFeatures {
return &DriverFeatures{
SupportReturnInsertedID: false,
}
}
func (g *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.ORACLE}
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
@@ -837,8 +872,7 @@ func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error)
names := dsnPattern.SubexpNames()
for i, match := range matches {
switch names[i] {
case "dbname":
if names[i] == "dbname" {
db.DBName = match
}
}
@@ -848,12 +882,33 @@ func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error)
return db, nil
}
func (g *godrorDriver) GenScanResult(colType string) (interface{}, error) {
switch colType {
case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB":
var s sql.NullString
return &s, nil
case "NUMBER":
var s sql.NullString
return &s, nil
case "DATE":
var s sql.NullTime
return &s, nil
case "BLOB":
var r sql.RawBytes
return &r, nil
default:
var r sql.RawBytes
return &r, nil
}
}
type oci8Driver struct {
godrorDriver
}
// dataSourceName=user/password@ipv4:port/dbname
// dataSourceName=user/password@[ipv6]:port/dbname
func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) {
func (o *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.ORACLE}
dsnPattern := regexp.MustCompile(
`^(?P<user>.*)\/(?P<password>.*)@` + // user:password@
@@ -862,8 +917,7 @@ func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) {
matches := dsnPattern.FindStringSubmatch(dataSourceName)
names := dsnPattern.SubexpNames()
for i, match := range matches {
switch names[i] {
case "dbname":
if names[i] == "dbname" {
db.DBName = match
}
}

View File

@@ -6,6 +6,7 @@ package dialects
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
@@ -777,12 +778,24 @@ var (
var (
// DefaultPostgresSchema default postgres schema
DefaultPostgresSchema = "public"
postgresColAliases = map[string]string{
"numeric": "decimal",
}
)
type postgres struct {
Base
}
// Alias returns a alias of column
func (db *postgres) Alias(col string) string {
v, ok := postgresColAliases[strings.ToLower(col)]
if ok {
return v
}
return col
}
func (db *postgres) Init(uri *URI) error {
db.quoter = postgresQuoter
return db.Base.Init(db, uri)
@@ -797,7 +810,10 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas
var version string
if !rows.Next() {
return nil, errors.New("Unknow version")
if rows.Err() != nil {
return nil, rows.Err()
}
return nil, errors.New("unknow version")
}
if err := rows.Scan(&version); err != nil {
@@ -860,21 +876,16 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) {
}
}
// FormatBytes formats bytes
func (db *postgres) FormatBytes(bs []byte) string {
return fmt.Sprintf("E'\\x%x'", bs)
}
func (db *postgres) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {
case schemas.TinyInt:
case schemas.TinyInt, schemas.UnsignedTinyInt:
res = schemas.SmallInt
return res
case schemas.Bit:
res = schemas.Boolean
return res
case schemas.MediumInt, schemas.Int, schemas.Integer:
case schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedMediumInt, schemas.UnsignedSmallInt:
if c.IsAutoIncrement {
return schemas.Serial
}
@@ -930,6 +941,21 @@ func (db *postgres) SQLType(c *schemas.Column) string {
return res
}
func (db *postgres) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATETIME", "TIMESTAMP":
return schemas.TIME_TYPE
case "VARCHAR", "TEXT":
return schemas.TEXT_TYPE
case "BIGINT", "BIGSERIAL", "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL", "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION":
return schemas.NUMERIC_TYPE
case "BOOL":
return schemas.BOOL_TYPE
default:
return schemas.UNKNOW_TYPE
}
}
func (db *postgres) IsReserved(name string) bool {
_, ok := postgresReservedWords[strings.ToUpper(name)]
return ok
@@ -1039,7 +1065,10 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab
}
defer rows.Close()
return rows.Next(), nil
if rows.Next() {
return true, nil
}
return false, rows.Err()
}
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
@@ -1169,7 +1198,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
}
}
if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, fmt.Errorf("Unknown colType: %s - %s", dataType, col.SQLType.Name)
return nil, nil, fmt.Errorf("unknown colType: %s - %s", dataType, col.SQLType.Name)
}
col.Length = maxLen
@@ -1177,19 +1206,22 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
if !col.DefaultIsEmpty {
if col.SQLType.IsText() {
if strings.HasSuffix(col.Default, "::character varying") {
col.Default = strings.TrimRight(col.Default, "::character varying")
col.Default = strings.TrimSuffix(col.Default, "::character varying")
} else if !strings.HasPrefix(col.Default, "'") {
col.Default = "'" + col.Default + "'"
}
} else if col.SQLType.IsTime() {
if strings.HasSuffix(col.Default, "::timestamp without time zone") {
col.Default = strings.TrimRight(col.Default, "::timestamp without time zone")
col.Default = strings.TrimSuffix(col.Default, "::timestamp without time zone")
}
}
}
cols[col.Name] = col
colSeq = append(colSeq, col.Name)
}
if rows.Err() != nil {
return nil, nil, rows.Err()
}
return colSeq, cols, nil
}
@@ -1220,6 +1252,9 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch
table.Name = name
tables = append(tables, table)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return tables, nil
}
@@ -1236,7 +1271,7 @@ func getIndexColName(indexdef string) []string {
func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
s := "SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1"
if len(db.getSchema()) != 0 {
args = append(args, db.getSchema())
s = s + " AND schemaname=$2"
@@ -1248,7 +1283,7 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN
}
defer rows.Close()
indexes := make(map[string]*schemas.Index, 0)
indexes := make(map[string]*schemas.Index)
for rows.Next() {
var indexType int
var indexName, indexdef string
@@ -1290,6 +1325,9 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN
index.IsRegular = isRegular
indexes[index.Name] = index
}
if rows.Err() != nil {
return nil, rows.Err()
}
return indexes, nil
}
@@ -1298,18 +1336,11 @@ func (db *postgres) Filters() []Filter {
}
type pqDriver struct {
baseDriver
}
type values map[string]string
func (vs values) Set(k, v string) {
vs[k] = v
}
func (vs values) Get(k string) (v string) {
return vs[k]
}
func parseURL(connstr string) (string, error) {
u, err := url.Parse(connstr)
if err != nil {
@@ -1329,30 +1360,94 @@ func parseURL(connstr string) (string, error) {
return "", nil
}
func parseOpts(name string, o values) error {
if len(name) == 0 {
return fmt.Errorf("invalid options: %s", name)
func parseOpts(urlStr string, o values) error {
if len(urlStr) == 0 {
return fmt.Errorf("invalid options: %s", urlStr)
}
name = strings.TrimSpace(name)
urlStr = strings.TrimSpace(urlStr)
ps := strings.Split(name, " ")
for _, p := range ps {
kv := strings.Split(p, "=")
if len(kv) < 2 {
return fmt.Errorf("invalid option: %q", p)
var (
inQuote bool
state int // 0 key, 1 space, 2 value, 3 equal
start int
key string
)
for i, c := range urlStr {
switch c {
case ' ':
if !inQuote {
if state == 2 {
state = 1
v := urlStr[start:i]
if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") {
v = v[1 : len(v)-1]
} else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") {
return fmt.Errorf("wrong single quote in %d of %s", i, urlStr)
}
o[key] = v
} else if state != 1 {
return fmt.Errorf("wrong format: %v", urlStr)
}
}
case '\'':
if state == 3 {
state = 2
start = i
} else if state != 2 {
return fmt.Errorf("wrong format: %v", urlStr)
}
inQuote = !inQuote
case '=':
if !inQuote {
if state != 0 {
return fmt.Errorf("wrong format: %v", urlStr)
}
key = urlStr[start:i]
state = 3
}
default:
if state == 3 {
state = 2
start = i
} else if state == 1 {
state = 0
start = i
}
}
if i == len(urlStr)-1 {
if state != 2 {
return errors.New("no value matched key")
}
v := urlStr[start : i+1]
if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") {
v = v[1 : len(v)-1]
} else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") {
return fmt.Errorf("wrong single quote in %d of %s", i, urlStr)
}
o[key] = v
}
o.Set(kv[0], kv[1])
}
return nil
}
func (p *pqDriver) Features() *DriverFeatures {
return &DriverFeatures{
SupportReturnInsertedID: false,
}
}
func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.POSTGRES}
var err error
if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
var err error
if strings.Contains(dataSourceName, "://") {
if !strings.HasPrefix(dataSourceName, "postgresql://") && !strings.HasPrefix(dataSourceName, "postgres://") {
return nil, fmt.Errorf("unsupported protocol %v", dataSourceName)
}
db.DBName, err = parseURL(dataSourceName)
if err != nil {
return nil, err
@@ -1364,7 +1459,7 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
return nil, err
}
db.DBName = o.Get("dbname")
db.DBName = o["dbname"]
}
if db.DBName == "" {
@@ -1374,6 +1469,32 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
return db, nil
}
func (p *pqDriver) GenScanResult(colType string) (interface{}, error) {
switch colType {
case "VARCHAR", "TEXT":
var s sql.NullString
return &s, nil
case "BIGINT", "BIGSERIAL":
var s sql.NullInt64
return &s, nil
case "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL":
var s sql.NullInt32
return &s, nil
case "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION":
var s sql.NullFloat64
return &s, nil
case "DATETIME", "TIMESTAMP":
var s sql.NullTime
return &s, nil
case "BOOL":
var s sql.NullBool
return &s, nil
default:
var r sql.RawBytes
return &r, nil
}
}
type pqDriverPgx struct {
pqDriver
}
@@ -1401,6 +1522,9 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri
parts := strings.Split(defaultSchema, ",")
return strings.TrimSpace(parts[len(parts)-1]), nil
}
if rows.Err() != nil {
return "", rows.Err()
}
return "", errors.New("No default schema")
return "", errors.New("no default schema")
}

View File

@@ -169,7 +169,10 @@ func (db *sqlite3) Version(ctx context.Context, queryer core.Queryer) (*schemas.
var version string
if !rows.Next() {
return nil, errors.New("Unknow version")
if rows.Err() != nil {
return nil, rows.Err()
}
return nil, errors.New("unknow version")
}
if err := rows.Scan(&version); err != nil {
@@ -214,8 +217,9 @@ func (db *sqlite3) SQLType(c *schemas.Column) string {
case schemas.Char, schemas.Varchar, schemas.NVarchar, schemas.TinyText,
schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json:
return schemas.Text
case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt,
schemas.UnsignedBigInt, schemas.UnsignedInt:
case schemas.Bit, schemas.TinyInt, schemas.UnsignedTinyInt, schemas.SmallInt,
schemas.UnsignedSmallInt, schemas.MediumInt, schemas.Int, schemas.UnsignedInt,
schemas.BigInt, schemas.UnsignedBigInt, schemas.Integer:
return schemas.Integer
case schemas.Float, schemas.Double, schemas.Real:
return schemas.Real
@@ -233,8 +237,19 @@ func (db *sqlite3) SQLType(c *schemas.Column) string {
}
}
func (db *sqlite3) FormatBytes(bs []byte) string {
return fmt.Sprintf("X'%x'", bs)
func (db *sqlite3) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATETIME":
return schemas.TIME_TYPE
case "TEXT":
return schemas.TEXT_TYPE
case "INTEGER", "REAL", "NUMERIC", "DECIMAL":
return schemas.NUMERIC_TYPE
case "BLOB":
return schemas.BLOB_TYPE
default:
return schemas.UNKNOW_TYPE
}
}
func (db *sqlite3) IsReserved(name string) bool {
@@ -404,12 +419,14 @@ func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableNa
defer rows.Close()
var name string
for rows.Next() {
if rows.Next() {
err = rows.Scan(&name)
if err != nil {
return nil, nil, err
}
break
}
if rows.Err() != nil {
return nil, nil, rows.Err()
}
if name == "" {
@@ -472,6 +489,9 @@ func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*sche
}
tables = append(tables, table)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return tables, nil
}
@@ -485,7 +505,7 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa
}
defer rows.Close()
indexes := make(map[string]*schemas.Index, 0)
indexes := make(map[string]*schemas.Index)
for rows.Next() {
var tmpSQL sql.NullString
err = rows.Scan(&tmpSQL)
@@ -531,6 +551,9 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa
index.IsRegular = isRegular
indexes[index.Name] = index
}
if rows.Err() != nil {
return nil, rows.Err()
}
return indexes, nil
}
@@ -540,6 +563,13 @@ func (db *sqlite3) Filters() []Filter {
}
type sqlite3Driver struct {
baseDriver
}
func (p *sqlite3Driver) Features() *DriverFeatures {
return &DriverFeatures{
SupportReturnInsertedID: true,
}
}
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
@@ -549,3 +579,29 @@ func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil
}
func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) {
switch colType {
case "TEXT":
var s sql.NullString
return &s, nil
case "INTEGER":
var s sql.NullInt64
return &s, nil
case "DATETIME":
var s sql.NullTime
return &s, nil
case "REAL":
var s sql.NullFloat64
return &s, nil
case "NUMERIC", "DECIMAL":
var s sql.NullString
return &s, nil
case "BLOB":
var s sql.RawBytes
return &s, nil
default:
var r sql.NullString
return &r, nil
}
}

73
vendor/xorm.io/xorm/dialects/time.go generated vendored
View File

@@ -5,50 +5,57 @@
package dialects
import (
"strings"
"time"
"xorm.io/xorm/schemas"
)
// FormatTime format time as column type
func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName {
case schemas.Time:
s := t.Format("2006-01-02 15:04:05") // time.RFC3339
v = s[11:19]
case schemas.Date:
v = t.Format("2006-01-02")
case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar.
if dialect.URI().DBType == schemas.ORACLE {
v = t
} else {
v = t.Format("2006-01-02 15:04:05")
}
case schemas.TimeStampz:
if dialect.URI().DBType == schemas.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
} else {
v = t.Format(time.RFC3339Nano)
}
case schemas.BigInt, schemas.Int:
v = t.Unix()
default:
v = t
}
return
}
// FormatColumnTime format column time
func FormatColumnTime(dialect Dialect, defaultTimeZone *time.Location, col *schemas.Column, t time.Time) (v interface{}) {
func FormatColumnTime(dialect Dialect, dbLocation *time.Location, col *schemas.Column, t time.Time) (interface{}, error) {
if t.IsZero() {
if col.Nullable {
return nil
return nil, nil
}
if col.SQLType.IsNumeric() {
return 0, nil
}
return ""
}
var tmZone = dbLocation
if col.TimeZone != nil {
return FormatTime(dialect, col.SQLType.Name, t.In(col.TimeZone))
tmZone = col.TimeZone
}
t = t.In(tmZone)
switch col.SQLType.Name {
case schemas.Date:
return t.Format("2006-01-02"), nil
case schemas.Time:
var layout = "15:04:05"
if col.Length > 0 {
layout += "." + strings.Repeat("0", col.Length)
}
return t.Format(layout), nil
case schemas.DateTime, schemas.TimeStamp:
var layout = "2006-01-02 15:04:05"
if col.Length > 0 {
layout += "." + strings.Repeat("0", col.Length)
}
return t.Format(layout), nil
case schemas.Varchar:
return t.Format("2006-01-02 15:04:05"), nil
case schemas.TimeStampz:
if dialect.URI().DBType == schemas.MSSQL {
return t.Format("2006-01-02T15:04:05.9999999Z07:00"), nil
} else {
return t.Format(time.RFC3339Nano), nil
}
case schemas.BigInt, schemas.Int:
return t.Unix(), nil
default:
return t, nil
}
return FormatTime(dialect, col.SQLType.Name, t.In(defaultTimeZone))
}