// Copyright 2013 The ql Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSES/QL-LICENSE file. // Copyright 2015 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // See the License for the specific language governing permissions and // limitations under the License. package evaluator import ( "fmt" "strings" "github.com/juju/errors" "github.com/ngaut/log" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/types" "golang.org/x/text/transform" ) // https://dev.mysql.com/doc/refman/5.7/en/string-functions.html func builtinLength(args []types.Datum, _ context.Context) (d types.Datum, err error) { switch args[0].Kind() { case types.KindNull: d.SetNull() return d, nil default: s, err := args[0].ToString() if err != nil { d.SetNull() return d, errors.Trace(err) } d.SetInt64(int64(len(s))) return d, nil } } // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_concat func builtinConcat(args []types.Datum, _ context.Context) (d types.Datum, err error) { var s []byte for _, a := range args { if a.Kind() == types.KindNull { d.SetNull() return d, nil } var ss string ss, err = a.ToString() if err != nil { d.SetNull() return d, errors.Trace(err) } s = append(s, []byte(ss)...) } d.SetBytesAsString(s) return d, nil } // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_concat-ws func builtinConcatWS(args []types.Datum, _ context.Context) (d types.Datum, err error) { var sep string s := make([]string, 0, len(args)) for i, a := range args { if a.Kind() == types.KindNull { if i == 0 { d.SetNull() return d, nil } continue } ss, err := a.ToString() if err != nil { d.SetNull() return d, errors.Trace(err) } if i == 0 { sep = ss continue } s = append(s, ss) } d.SetString(strings.Join(s, sep)) return d, nil } // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_left func builtinLeft(args []types.Datum, _ context.Context) (d types.Datum, err error) { str, err := args[0].ToString() if err != nil { d.SetNull() return d, errors.Trace(err) } length, err := args[1].ToInt64() if err != nil { d.SetNull() return d, errors.Trace(err) } l := int(length) if l < 0 { l = 0 } else if l > len(str) { l = len(str) } d.SetString(str[:l]) return d, nil } // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_repeat func builtinRepeat(args []types.Datum, _ context.Context) (d types.Datum, err error) { str, err := args[0].ToString() if err != nil { d.SetNull() return d, err } ch := fmt.Sprintf("%v", str) num := 0 x := args[1] switch x.Kind() { case types.KindInt64: num = int(x.GetInt64()) case types.KindUint64: num = int(x.GetUint64()) } if num < 1 { d.SetString("") return d, nil } d.SetString(strings.Repeat(ch, num)) return d, nil } // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_lower func builtinLower(args []types.Datum, _ context.Context) (d types.Datum, err error) { x := args[0] switch x.Kind() { case types.KindNull: d.SetNull() return d, nil default: s, err := x.ToString() if err != nil { d.SetNull() return d, errors.Trace(err) } d.SetString(strings.ToLower(s)) return d, nil } } // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_upper func builtinUpper(args []types.Datum, _ context.Context) (d types.Datum, err error) { x := args[0] switch x.Kind() { case types.KindNull: d.SetNull() return d, nil default: s, err := x.ToString() if err != nil { d.SetNull() return d, errors.Trace(err) } d.SetString(strings.ToUpper(s)) return d, nil } } // See: https://dev.mysql.com/doc/refman/5.7/en/string-comparison-functions.html func builtinStrcmp(args []interface{}, _ context.Context) (interface{}, error) { if args[0] == nil || args[1] == nil { return nil, nil } left, err := types.ToString(args[0]) if err != nil { return nil, errors.Trace(err) } right, err := types.ToString(args[1]) if err != nil { return nil, errors.Trace(err) } res := types.CompareString(left, right) return res, nil } // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_replace func builtinReplace(args []interface{}, _ context.Context) (interface{}, error) { for _, arg := range args { if arg == nil { return nil, nil } } str, err := types.ToString(args[0]) if err != nil { return nil, errors.Trace(err) } oldStr, err := types.ToString(args[1]) if err != nil { return nil, errors.Trace(err) } newStr, err := types.ToString(args[2]) if err != nil { return nil, errors.Trace(err) } return strings.Replace(str, oldStr, newStr, -1), nil } // See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_convert func builtinConvert(args []interface{}, _ context.Context) (interface{}, error) { value := args[0] Charset := args[1].(string) // Casting nil to any type returns nil if value == nil { return nil, nil } str, ok := value.(string) if !ok { return nil, nil } if strings.ToLower(Charset) == "ascii" { return value, nil } else if strings.ToLower(Charset) == "utf8mb4" { return value, nil } encoding, _ := charset.Lookup(Charset) if encoding == nil { return nil, errors.Errorf("unknown encoding: %s", Charset) } target, _, err := transform.String(encoding.NewDecoder(), str) if err != nil { log.Errorf("Convert %s to %s with error: %v", str, Charset, err) return nil, errors.Trace(err) } return target, nil } func builtinSubstring(args []interface{}, _ context.Context) (interface{}, error) { // The meaning of the elements of args. // arg[0] -> StrExpr // arg[1] -> Pos // arg[2] -> Len (Optional) str, err := types.ToString(args[0]) if err != nil { return nil, errors.Errorf("Substring invalid args, need string but get %T", args[0]) } t := args[1] p, ok := t.(int64) if !ok { return nil, errors.Errorf("Substring invalid pos args, need int but get %T", t) } pos := int(p) length := -1 if len(args) == 3 { t = args[2] p, ok = t.(int64) if !ok { return nil, errors.Errorf("Substring invalid pos args, need int but get %T", t) } length = int(p) } // The forms without a len argument return a substring from string str starting at position pos. // The forms with a len argument return a substring len characters long from string str, starting at position pos. // The forms that use FROM are standard SQL syntax. It is also possible to use a negative value for pos. // In this case, the beginning of the substring is pos characters from the end of the string, rather than the beginning. // A negative value may be used for pos in any of the forms of this function. if pos < 0 { pos = len(str) + pos } else { pos-- } if pos > len(str) || pos <= 0 { pos = len(str) } end := len(str) if length != -1 { end = pos + length } if end > len(str) { end = len(str) } return str[pos:end], nil } // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_substring-index func builtinSubstringIndex(args []interface{}, _ context.Context) (interface{}, error) { // The meaning of the elements of args. // args[0] -> StrExpr // args[1] -> Delim // args[2] -> Count fs := args[0] str, err := types.ToString(fs) if err != nil { return nil, errors.Errorf("Substring_Index invalid args, need string but get %T", fs) } t := args[1] delim, err := types.ToString(t) if err != nil { return nil, errors.Errorf("Substring_Index invalid delim, need string but get %T", t) } if len(delim) == 0 { return "", nil } t = args[2] c, err := types.ToInt64(t) if err != nil { return nil, errors.Trace(err) } count := int(c) strs := strings.Split(str, delim) var ( start = 0 end = len(strs) ) if count > 0 { // If count is positive, everything to the left of the final delimiter (counting from the left) is returned. if count < end { end = count } } else { // If count is negative, everything to the right of the final delimiter (counting from the right) is returned. count = -count if count < end { start = end - count } } substrs := strs[start:end] return strings.Join(substrs, delim), nil } // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_locate func builtinLocate(args []interface{}, _ context.Context) (interface{}, error) { // The meaning of the elements of args. // args[0] -> SubStr // args[1] -> Str // args[2] -> Pos // eval str fs := args[1] if fs == nil { return nil, nil } str, err := types.ToString(fs) if err != nil { return nil, errors.Trace(err) } // eval substr fs = args[0] if fs == nil { return nil, nil } subStr, err := types.ToString(fs) if err != nil { return nil, errors.Trace(err) } // eval pos pos := int64(0) if len(args) == 3 { t := args[2] p, err := types.ToInt64(t) if err != nil { return nil, errors.Trace(err) } pos = p - 1 if pos < 0 || pos > int64(len(str)) { return 0, nil } if pos > int64(len(str)-len(subStr)) { return 0, nil } } if len(subStr) == 0 { return pos + 1, nil } i := strings.Index(str[pos:], subStr) if i == -1 { return 0, nil } return int64(i) + pos + 1, nil } const spaceChars = "\n\t\r " // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_trim func builtinTrim(args []interface{}, _ context.Context) (interface{}, error) { // args[0] -> Str // args[1] -> RemStr // args[2] -> Direction // eval str fs := args[0] if fs == nil { return nil, nil } str, err := types.ToString(fs) if err != nil { return nil, errors.Trace(err) } remstr := "" // eval remstr if len(args) > 1 { fs = args[1] if fs != nil { remstr, err = types.ToString(fs) if err != nil { return nil, errors.Trace(err) } } } // do trim var result string var direction ast.TrimDirectionType if len(args) > 2 { direction = args[2].(ast.TrimDirectionType) } else { direction = ast.TrimBothDefault } if direction == ast.TrimLeading { if len(remstr) > 0 { result = trimLeft(str, remstr) } else { result = strings.TrimLeft(str, spaceChars) } } else if direction == ast.TrimTrailing { if len(remstr) > 0 { result = trimRight(str, remstr) } else { result = strings.TrimRight(str, spaceChars) } } else if len(remstr) > 0 { x := trimLeft(str, remstr) result = trimRight(x, remstr) } else { result = strings.Trim(str, spaceChars) } return result, nil } func trimLeft(str, remstr string) string { for { x := strings.TrimPrefix(str, remstr) if len(x) == len(str) { return x } str = x } } func trimRight(str, remstr string) string { for { x := strings.TrimSuffix(str, remstr) if len(x) == len(str) { return x } str = x } }