1
1
mirror of https://github.com/go-gitea/gitea synced 2024-12-23 09:04:26 +00:00

update xorm for fixing bug on processor BeforeSet and AfterSet when Find a map (#987)

This commit is contained in:
Lunny Xiao 2017-02-20 19:33:10 +08:00 committed by GitHub
parent 04fdeb9d8d
commit c5f8b96dda
7 changed files with 363 additions and 153 deletions

249
vendor/github.com/go-xorm/xorm/convert.go generated vendored Normal file
View File

@ -0,0 +1,249 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strconv"
"time"
)
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
func strconvErr(err error) error {
if ne, ok := err.(*strconv.NumError); ok {
return ne.Err
}
return err
}
func cloneBytes(b []byte) []byte {
if b == nil {
return nil
} else {
c := make([]byte, len(b))
copy(c, b)
return c
}
}
func asString(src interface{}) string {
switch v := src.(type) {
case string:
return v
case []byte:
return string(v)
}
rv := reflect.ValueOf(src)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(rv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.FormatUint(rv.Uint(), 10)
case reflect.Float64:
return strconv.FormatFloat(rv.Float(), 'g', -1, 64)
case reflect.Float32:
return strconv.FormatFloat(rv.Float(), 'g', -1, 32)
case reflect.Bool:
return strconv.FormatBool(rv.Bool())
}
return fmt.Sprintf("%v", src)
}
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.AppendInt(buf, rv.Int(), 10), true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.AppendUint(buf, rv.Uint(), 10), true
case reflect.Float32:
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true
case reflect.Float64:
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true
case reflect.Bool:
return strconv.AppendBool(buf, rv.Bool()), true
case reflect.String:
s := rv.String()
return append(buf, s...), true
}
return
}
// convertAssign copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information.
// dest should be a pointer type.
func convertAssign(dest, src interface{}) error {
// Common cases, without reflect.
switch s := src.(type) {
case string:
switch d := dest.(type) {
case *string:
if d == nil {
return errNilPtr
}
*d = s
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = []byte(s)
return nil
}
case []byte:
switch d := dest.(type) {
case *string:
if d == nil {
return errNilPtr
}
*d = string(s)
return nil
case *interface{}:
if d == nil {
return errNilPtr
}
*d = cloneBytes(s)
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = cloneBytes(s)
return nil
}
case time.Time:
switch d := dest.(type) {
case *string:
*d = s.Format(time.RFC3339Nano)
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = []byte(s.Format(time.RFC3339Nano))
return nil
}
case nil:
switch d := dest.(type) {
case *interface{}:
if d == nil {
return errNilPtr
}
*d = nil
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = nil
return nil
}
}
var sv reflect.Value
switch d := dest.(type) {
case *string:
sv = reflect.ValueOf(src)
switch sv.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
*d = asString(src)
return nil
}
case *[]byte:
sv = reflect.ValueOf(src)
if b, ok := asBytes(nil, sv); ok {
*d = b
return nil
}
case *bool:
bv, err := driver.Bool.ConvertValue(src)
if err == nil {
*d = bv.(bool)
}
return err
case *interface{}:
*d = src
return nil
}
dpv := reflect.ValueOf(dest)
if dpv.Kind() != reflect.Ptr {
return errors.New("destination not a pointer")
}
if dpv.IsNil() {
return errNilPtr
}
if !sv.IsValid() {
sv = reflect.ValueOf(src)
}
dv := reflect.Indirect(dpv)
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
switch b := src.(type) {
case []byte:
dv.Set(reflect.ValueOf(cloneBytes(b)))
default:
dv.Set(sv)
}
return nil
}
if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
dv.Set(sv.Convert(dv.Type()))
return nil
}
switch dv.Kind() {
case reflect.Ptr:
if src == nil {
dv.Set(reflect.Zero(dv.Type()))
return nil
} else {
dv.Set(reflect.New(dv.Type().Elem()))
return convertAssign(dv.Interface(), src)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s := asString(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetInt(i64)
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
s := asString(src)
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetUint(u64)
return nil
case reflect.Float32, reflect.Float64:
s := asString(src)
f64, err := strconv.ParseFloat(s, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetFloat(f64)
return nil
case reflect.String:
dv.SetString(asString(src))
return nil
}
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
}

View File

@ -17,74 +17,83 @@ import (
) )
// str2PK convert string value to primary key value according to tp // str2PK convert string value to primary key value according to tp
func str2PK(s string, tp reflect.Type) (interface{}, error) { func str2PKValue(s string, tp reflect.Type) (reflect.Value, error) {
var err error var err error
var result interface{} var result interface{}
var defReturn = reflect.Zero(tp)
switch tp.Kind() { switch tp.Kind() {
case reflect.Int: case reflect.Int:
result, err = strconv.Atoi(s) result, err = strconv.Atoi(s)
if err != nil { if err != nil {
return nil, errors.New("convert " + s + " as int: " + err.Error()) return defReturn, fmt.Errorf("convert %s as int: %s", s, err.Error())
} }
case reflect.Int8: case reflect.Int8:
x, err := strconv.Atoi(s) x, err := strconv.Atoi(s)
if err != nil { if err != nil {
return nil, errors.New("convert " + s + " as int16: " + err.Error()) return defReturn, fmt.Errorf("convert %s as int8: %s", s, err.Error())
} }
result = int8(x) result = int8(x)
case reflect.Int16: case reflect.Int16:
x, err := strconv.Atoi(s) x, err := strconv.Atoi(s)
if err != nil { if err != nil {
return nil, errors.New("convert " + s + " as int16: " + err.Error()) return defReturn, fmt.Errorf("convert %s as int16: %s", s, err.Error())
} }
result = int16(x) result = int16(x)
case reflect.Int32: case reflect.Int32:
x, err := strconv.Atoi(s) x, err := strconv.Atoi(s)
if err != nil { if err != nil {
return nil, errors.New("convert " + s + " as int32: " + err.Error()) return defReturn, fmt.Errorf("convert %s as int32: %s", s, err.Error())
} }
result = int32(x) result = int32(x)
case reflect.Int64: case reflect.Int64:
result, err = strconv.ParseInt(s, 10, 64) result, err = strconv.ParseInt(s, 10, 64)
if err != nil { if err != nil {
return nil, errors.New("convert " + s + " as int64: " + err.Error()) return defReturn, fmt.Errorf("convert %s as int64: %s", s, err.Error())
} }
case reflect.Uint: case reflect.Uint:
x, err := strconv.ParseUint(s, 10, 64) x, err := strconv.ParseUint(s, 10, 64)
if err != nil { if err != nil {
return nil, errors.New("convert " + s + " as uint: " + err.Error()) return defReturn, fmt.Errorf("convert %s as uint: %s", s, err.Error())
} }
result = uint(x) result = uint(x)
case reflect.Uint8: case reflect.Uint8:
x, err := strconv.ParseUint(s, 10, 64) x, err := strconv.ParseUint(s, 10, 64)
if err != nil { if err != nil {
return nil, errors.New("convert " + s + " as uint8: " + err.Error()) return defReturn, fmt.Errorf("convert %s as uint8: %s", s, err.Error())
} }
result = uint8(x) result = uint8(x)
case reflect.Uint16: case reflect.Uint16:
x, err := strconv.ParseUint(s, 10, 64) x, err := strconv.ParseUint(s, 10, 64)
if err != nil { if err != nil {
return nil, errors.New("convert " + s + " as uint16: " + err.Error()) return defReturn, fmt.Errorf("convert %s as uint16: %s", s, err.Error())
} }
result = uint16(x) result = uint16(x)
case reflect.Uint32: case reflect.Uint32:
x, err := strconv.ParseUint(s, 10, 64) x, err := strconv.ParseUint(s, 10, 64)
if err != nil { if err != nil {
return nil, errors.New("convert " + s + " as uint32: " + err.Error()) return defReturn, fmt.Errorf("convert %s as uint32: %s", s, err.Error())
} }
result = uint32(x) result = uint32(x)
case reflect.Uint64: case reflect.Uint64:
result, err = strconv.ParseUint(s, 10, 64) result, err = strconv.ParseUint(s, 10, 64)
if err != nil { if err != nil {
return nil, errors.New("convert " + s + " as uint64: " + err.Error()) return defReturn, fmt.Errorf("convert %s as uint64: %s", s, err.Error())
} }
case reflect.String: case reflect.String:
result = s result = s
default: default:
panic("unsupported convert type") return defReturn, errors.New("unsupported convert type")
} }
result = reflect.ValueOf(result).Convert(tp).Interface() return reflect.ValueOf(result).Convert(tp), nil
return result, nil }
func str2PK(s string, tp reflect.Type) (interface{}, error) {
v, err := str2PKValue(s, tp)
if err != nil {
return nil, err
}
return v.Interface(), nil
} }
func splitTag(tag string) (tags []string) { func splitTag(tag string) (tags []string) {

View File

@ -114,7 +114,8 @@ func (rows *Rows) Scan(bean interface{}) error {
return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
} }
return rows.session.row2Bean(rows.rows, rows.fields, rows.fieldsCount, bean) _, err := rows.session.row2Bean(rows.rows, rows.fields, rows.fieldsCount, bean)
return err
} }
// Close session if session.IsAutoClose is true, and claimed any opened resources // Close session if session.IsAutoClose is true, and claimed any opened resources

View File

@ -386,52 +386,6 @@ func cleanupProcessorsClosures(slices *[]func(interface{})) {
} }
} }
func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error {
dataStruct := rValue(obj)
if dataStruct.Kind() != reflect.Struct {
return errors.New("Expected a pointer to a struct")
}
var col *core.Column
session.Statement.setRefValue(dataStruct)
table := session.Statement.RefTable
tableName := session.Statement.tableName
for key, data := range objMap {
if col = table.GetColumn(key); col == nil {
session.Engine.logger.Warnf("struct %v's has not field %v. %v",
table.Type.Name(), key, table.ColumnsSeq())
continue
}
fieldName := col.FieldName
fieldPath := strings.Split(fieldName, ".")
var fieldValue reflect.Value
if len(fieldPath) > 2 {
session.Engine.logger.Error("Unsupported mutliderive", fieldName)
continue
} else if len(fieldPath) == 2 {
parentField := dataStruct.FieldByName(fieldPath[0])
if parentField.IsValid() {
fieldValue = parentField.FieldByName(fieldPath[1])
}
} else {
fieldValue = dataStruct.FieldByName(fieldName)
}
if !fieldValue.IsValid() || !fieldValue.CanSet() {
session.Engine.logger.Warnf("table %v's column %v is not valid or cannot set", tableName, key)
continue
}
err := session.bytes2Value(col, &fieldValue, data)
if err != nil {
return err
}
}
return nil
}
func (session *Session) canCache() bool { func (session *Session) canCache() bool {
if session.Statement.RefTable == nil || if session.Statement.RefTable == nil ||
session.Statement.JoinStr != "" || session.Statement.JoinStr != "" ||
@ -485,24 +439,28 @@ type Cell *interface{}
func (session *Session) rows2Beans(rows *core.Rows, fields []string, fieldsCount int, func (session *Session) rows2Beans(rows *core.Rows, fields []string, fieldsCount int,
table *core.Table, newElemFunc func() reflect.Value, table *core.Table, newElemFunc func() reflect.Value,
sliceValueSetFunc func(*reflect.Value)) error { sliceValueSetFunc func(*reflect.Value, core.PK) error) error {
for rows.Next() { for rows.Next() {
var newValue = newElemFunc() var newValue = newElemFunc()
bean := newValue.Interface() bean := newValue.Interface()
dataStruct := rValue(bean) dataStruct := rValue(bean)
err := session._row2Bean(rows, fields, fieldsCount, bean, &dataStruct, table) pk, err := session._row2Bean(rows, fields, fieldsCount, bean, &dataStruct, table)
if err != nil {
return err
}
err = sliceValueSetFunc(&newValue, pk)
if err != nil { if err != nil {
return err return err
} }
sliceValueSetFunc(&newValue)
} }
return nil return nil
} }
func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount int, bean interface{}) error { func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount int, bean interface{}) (core.PK, error) {
dataStruct := rValue(bean) dataStruct := rValue(bean)
if dataStruct.Kind() != reflect.Struct { if dataStruct.Kind() != reflect.Struct {
return errors.New("Expected a pointer to a struct") return nil, errors.New("Expected a pointer to a struct")
} }
session.Statement.setRefValue(dataStruct) session.Statement.setRefValue(dataStruct)
@ -510,14 +468,14 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
return session._row2Bean(rows, fields, fieldsCount, bean, &dataStruct, session.Statement.RefTable) return session._row2Bean(rows, fields, fieldsCount, bean, &dataStruct, session.Statement.RefTable)
} }
func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount int, bean interface{}, dataStruct *reflect.Value, table *core.Table) error { func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount int, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) {
scanResults := make([]interface{}, fieldsCount) scanResults := make([]interface{}, fieldsCount)
for i := 0; i < len(fields); i++ { for i := 0; i < len(fields); i++ {
var cell interface{} var cell interface{}
scanResults[i] = &cell scanResults[i] = &cell
} }
if err := rows.Scan(scanResults...); err != nil { if err := rows.Scan(scanResults...); err != nil {
return err return nil, err
} }
if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet { if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet {
@ -535,6 +493,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
}() }()
var tempMap = make(map[string]int) var tempMap = make(map[string]int)
var pk core.PK
for ii, key := range fields { for ii, key := range fields {
var idx int var idx int
var ok bool var ok bool
@ -579,10 +538,12 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
rawValueType := reflect.TypeOf(rawValue.Interface()) rawValueType := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface())
col := table.GetColumnIdx(key, idx)
if col.IsPrimaryKey {
pk = append(pk, rawValue.Interface())
}
fieldType := fieldValue.Type() fieldType := fieldValue.Type()
hasAssigned := false hasAssigned := false
col := table.GetColumnIdx(key, idx)
if col.SQLType.IsJson() { if col.SQLType.IsJson() {
var bs []byte var bs []byte
@ -591,7 +552,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
} else if rawValueType.ConvertibleTo(core.BytesType) { } else if rawValueType.ConvertibleTo(core.BytesType) {
bs = vv.Bytes() bs = vv.Bytes()
} else { } else {
return fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind())
} }
hasAssigned = true hasAssigned = true
@ -601,14 +562,14 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) err := json.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(key, err) session.Engine.logger.Error(key, err)
return err return nil, err
} }
} else { } else {
x := reflect.New(fieldType) x := reflect.New(fieldType)
err := json.Unmarshal(bs, x.Interface()) err := json.Unmarshal(bs, x.Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(key, err) session.Engine.logger.Error(key, err)
return err return nil, err
} }
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
} }
@ -633,14 +594,14 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) err := json.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.Engine.logger.Error(err)
return err return nil, err
} }
} else { } else {
x := reflect.New(fieldType) x := reflect.New(fieldType)
err := json.Unmarshal(bs, x.Interface()) err := json.Unmarshal(bs, x.Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.Engine.logger.Error(err)
return err return nil, err
} }
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
} }
@ -772,7 +733,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
err := json.Unmarshal([]byte(vv.String()), x.Interface()) err := json.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.Engine.logger.Error(err)
return err return nil, err
} }
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
} }
@ -783,7 +744,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
err := json.Unmarshal(vv.Bytes(), x.Interface()) err := json.Unmarshal(vv.Bytes(), x.Interface())
if err != nil { if err != nil {
session.Engine.logger.Error(err) session.Engine.logger.Error(err)
return err return nil, err
} }
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
} }
@ -835,14 +796,14 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
defer newsession.Close() defer newsession.Close()
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
if err != nil { if err != nil {
return err return nil, err
} }
if has { if has {
//v := structInter.Elem().Interface() //v := structInter.Elem().Interface()
//fieldValue.Set(reflect.ValueOf(v)) //fieldValue.Set(reflect.ValueOf(v))
fieldValue.Set(structInter.Elem()) fieldValue.Set(structInter.Elem())
} else { } else {
return errors.New("cascade obj is not exist") return nil, errors.New("cascade obj is not exist")
} }
} }
} else { } else {
@ -982,7 +943,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
} }
} }
} }
return nil return pk, nil
} }
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {

View File

@ -43,14 +43,12 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
pv := reflect.New(sliceElementType.Elem()) pv := reflect.New(sliceElementType.Elem())
session.Statement.setRefValue(pv.Elem()) session.Statement.setRefValue(pv.Elem())
} else { } else {
//return errors.New("slice type")
tp = tpNonStruct tp = tpNonStruct
} }
} else if sliceElementType.Kind() == reflect.Struct { } else if sliceElementType.Kind() == reflect.Struct {
pv := reflect.New(sliceElementType) pv := reflect.New(sliceElementType)
session.Statement.setRefValue(pv.Elem()) session.Statement.setRefValue(pv.Elem())
} else { } else {
//return errors.New("slice type")
tp = tpNonStruct tp = tpNonStruct
} }
} }
@ -148,62 +146,10 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
} }
if sliceValue.Kind() != reflect.Map { return session.noCacheFind(table, sliceValue, sqlStr, args...)
return session.noCacheFind(sliceValue, sqlStr, args...)
}
resultsSlice, err := session.query(sqlStr, args...)
if err != nil {
return err
}
keyType := sliceValue.Type().Key()
for _, results := range resultsSlice {
var newValue reflect.Value
if sliceElementType.Kind() == reflect.Ptr {
newValue = reflect.New(sliceElementType.Elem())
} else {
newValue = reflect.New(sliceElementType)
}
err := session.scanMapIntoStruct(newValue.Interface(), results)
if err != nil {
return err
}
var key interface{}
// if there is only one pk, we can put the id as map key.
if len(table.PrimaryKeys) == 1 {
key, err = str2PK(string(results[table.PrimaryKeys[0]]), keyType)
if err != nil {
return err
}
} else {
if keyType.Kind() != reflect.Slice {
panic("don't support multiple primary key's map has non-slice key type")
} else {
var keys core.PK = make([]interface{}, 0, len(table.PrimaryKeys))
for _, pk := range table.PrimaryKeys {
skey, err := str2PK(string(results[pk]), keyType)
if err != nil {
return err
}
keys = append(keys, skey)
}
key = keys
}
}
if sliceElementType.Kind() == reflect.Ptr {
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValue.Interface()))
} else {
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(newValue.Interface())))
}
}
return nil
} }
func (session *Session) noCacheFind(sliceValue reflect.Value, sqlStr string, args ...interface{}) error { func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error {
var rawRows *core.Rows var rawRows *core.Rows
var err error var err error
@ -224,27 +170,59 @@ func (session *Session) noCacheFind(sliceValue reflect.Value, sqlStr string, arg
} }
var newElemFunc func() reflect.Value var newElemFunc func() reflect.Value
sliceElementType := sliceValue.Type().Elem() elemType := containerValue.Type().Elem()
if sliceElementType.Kind() == reflect.Ptr { if elemType.Kind() == reflect.Ptr {
newElemFunc = func() reflect.Value { newElemFunc = func() reflect.Value {
return reflect.New(sliceElementType.Elem()) return reflect.New(elemType.Elem())
} }
} else { } else {
newElemFunc = func() reflect.Value { newElemFunc = func() reflect.Value {
return reflect.New(sliceElementType) return reflect.New(elemType)
} }
} }
var sliceValueSetFunc func(*reflect.Value) var containerValueSetFunc func(*reflect.Value, core.PK) error
if sliceValue.Kind() == reflect.Slice { if containerValue.Kind() == reflect.Slice {
if sliceElementType.Kind() == reflect.Ptr { if elemType.Kind() == reflect.Ptr {
sliceValueSetFunc = func(newValue *reflect.Value) { containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error {
sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) containerValue.Set(reflect.Append(containerValue, reflect.ValueOf(newValue.Interface())))
return nil
} }
} else { } else {
sliceValueSetFunc = func(newValue *reflect.Value) { containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error {
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) containerValue.Set(reflect.Append(containerValue, reflect.Indirect(reflect.ValueOf(newValue.Interface()))))
return nil
}
}
} else {
keyType := containerValue.Type().Key()
if len(table.PrimaryKeys) == 0 {
return errors.New("don't support multiple primary key's map has non-slice key type")
}
if len(table.PrimaryKeys) > 1 && keyType.Kind() != reflect.Slice {
return errors.New("don't support multiple primary key's map has non-slice key type")
}
if elemType.Kind() == reflect.Ptr {
containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error {
keyValue := reflect.New(keyType)
err := convertPKToValue(table, keyValue.Interface(), pk)
if err != nil {
return err
}
containerValue.SetMapIndex(keyValue.Elem(), reflect.ValueOf(newValue.Interface()))
return nil
}
} else {
containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error {
keyValue := reflect.New(keyType)
err := convertPKToValue(table, keyValue.Interface(), pk)
if err != nil {
return err
}
containerValue.SetMapIndex(keyValue.Elem(), reflect.Indirect(reflect.ValueOf(newValue.Interface())))
return nil
} }
} }
} }
@ -252,7 +230,7 @@ func (session *Session) noCacheFind(sliceValue reflect.Value, sqlStr string, arg
var newValue = newElemFunc() var newValue = newElemFunc()
dataStruct := rValue(newValue.Interface()) dataStruct := rValue(newValue.Interface())
if dataStruct.Kind() == reflect.Struct { if dataStruct.Kind() == reflect.Struct {
return session.rows2Beans(rawRows, fields, len(fields), session.Engine.autoMapType(dataStruct), newElemFunc, sliceValueSetFunc) return session.rows2Beans(rawRows, fields, len(fields), session.Engine.autoMapType(dataStruct), newElemFunc, containerValueSetFunc)
} }
for rawRows.Next() { for rawRows.Next() {
@ -263,8 +241,20 @@ func (session *Session) noCacheFind(sliceValue reflect.Value, sqlStr string, arg
return err return err
} }
sliceValueSetFunc(&newValue) if err := containerValueSetFunc(&newValue, nil); err != nil {
return err
} }
}
return nil
}
func convertPKToValue(table *core.Table, dst interface{}, pk core.PK) error {
cols := table.PKColumns()
if len(cols) == 1 {
return convertAssign(dst, pk[0])
}
dst = pk
return nil return nil
} }

View File

@ -67,7 +67,7 @@ func (session *Session) nocacheGet(bean interface{}, sqlStr string, args ...inte
if rawRows.Next() { if rawRows.Next() {
fields, err := rawRows.Columns() fields, err := rawRows.Columns()
if err == nil { if err == nil {
err = session.row2Bean(rawRows, fields, len(fields), bean) _, err = session.row2Bean(rawRows, fields, len(fields), bean)
} }
return true, err return true, err
} }

6
vendor/vendor.json vendored
View File

@ -455,10 +455,10 @@
"revisionTime": "2016-08-11T02:11:45Z" "revisionTime": "2016-08-11T02:11:45Z"
}, },
{ {
"checksumSHA1": "BGWfs63vC5cJuxhVRrj+7YJKz7A=", "checksumSHA1": "COlm4o3G1rUSqr33iumtjY1qKD8=",
"path": "github.com/go-xorm/xorm", "path": "github.com/go-xorm/xorm",
"revision": "19f6dfc2e8c069adc624ca56cf8127444159d5c1", "revision": "1bc93ba022236fcc94092fa40105b96e1d1d2346",
"revisionTime": "2017-02-10T01:55:37Z" "revisionTime": "2017-02-20T09:51:59Z"
}, },
{ {
"checksumSHA1": "1ft/4j5MFa7C9dPI9whL03HSUzk=", "checksumSHA1": "1ft/4j5MFa7C9dPI9whL03HSUzk=",