1
1
mirror of https://github.com/go-gitea/gitea synced 2025-01-10 17:54:27 +00:00

805 lines
19 KiB
Go
Raw Normal View History

package mssql
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
)
//go:generate stringer -type token

// token is the one-byte identifier that introduces each token in a server
// reply stream.
type token byte

// Token ids handled by the response parser (decimal value in the comment).
const (
	tokenReturnStatus token = 0x79 // 121
	tokenColMetadata  token = 0x81 // 129
	tokenOrder        token = 0xA9 // 169
	tokenError        token = 0xAA // 170
	tokenInfo         token = 0xAB // 171
	tokenReturnValue  token = 0xAC // 172
	tokenLoginAck     token = 0xAD // 173
	tokenRow          token = 0xD1 // 209
	tokenNbcRow       token = 0xD2 // 210
	tokenEnvChange    token = 0xE3 // 227
	tokenSSPI         token = 0xED // 237
	tokenDone         token = 0xFD // 253
	tokenDoneProc     token = 0xFE // 254
	tokenDoneInProc   token = 0xFF // 255
)
// Status bits carried by DONE, DONEPROC and DONEINPROC tokens.
// https://msdn.microsoft.com/en-us/library/dd340421.aspx
const (
	doneFinal    = 0x000
	doneMore     = 0x001
	doneError    = 0x002
	doneInxact   = 0x004
	doneCount    = 0x010
	doneAttn     = 0x020
	doneSrvError = 0x100
)
// ENVCHANGE record types: the first byte of each record inside an ENVCHANGE
// token, used by processEnvChg to decide how the rest of the record is laid out.
// http://msdn.microsoft.com/en-us/library/dd303449.aspx
const (
// Session settings.
envTypDatabase = 1
envTypLanguage = 2
envTypCharset = 3
envTypPacketSize = 4
// Collation and sort order.
envSortId = 5
envSortFlags = 6
envSqlCollation = 7
// Transaction state.
envTypBeginTran = 8
envTypCommitTran = 9
envTypRollbackTran = 10
envEnlistDTC = 11
envDefectTran = 12
envDatabaseMirrorPartner = 13
// Note: value 14 is not defined in this list.
envPromoteTran = 15
envTranMgrAddr = 16
envTranEnded = 17
envResetConnAck = 18
envStartedInstanceName = 19
envRouting = 20
)
// COLMETADATA flags: bit masks intended for the 16-bit Flags value that
// parseColMetadata72 reads for every column.
// https://msdn.microsoft.com/en-us/library/dd357363.aspx
const (
// colFlagNullable is the lowest flag bit.
colFlagNullable = 1
// TODO implement more flags
)
// tokenStruct is the element type of the channels that carry parsed tokens.
// It is the empty interface, so any value (token structs, rows, errors) can
// be sent; receivers are expected to type-switch on what they get.
type tokenStruct interface{}
// orderStruct is the parsed form of an ORDER token.
type orderStruct struct {
// ColIds holds the 16-bit column ids in the order they appear in the token.
ColIds []uint16
}
// doneStruct is the parsed form of a DONE or DONEPROC token.
type doneStruct struct {
// Status is a bit set of the done* flags.
Status uint16
// CurCmd is the second 16-bit field of the token, stored as read.
CurCmd uint16
// RowCount is the 64-bit count field; the parser only reports it when
// doneCount is set in Status.
RowCount uint64
// errors holds the ERROR tokens collected in the response before this
// token was read.
errors []Error
}
// isError reports whether the request failed: either server errors were
// collected alongside this token, or the doneError status bit is set.
func (d doneStruct) isError() bool {
	if d.Status&doneError != 0 {
		return true
	}
	return len(d.errors) > 0
}
// getError returns the most recently collected server error. When the token
// carries no errors it returns a generic Error with a fixed message.
func (d doneStruct) getError() Error {
	if n := len(d.errors); n > 0 {
		return d.errors[n-1]
	}
	return Error{Message: "Request failed but didn't provide reason"}
}
// doneInProcStruct is the parsed form of a DONEINPROC token. It has the same
// fields as doneStruct but, being a distinct defined type, none of its methods.
type doneInProcStruct doneStruct
// doneFlags2str maps each DONE status flag to the short name used when
// rendering a status word as text. The doneFinal entry has key 0, so it can
// never match a bit test of the form flags&key != 0.
var doneFlags2str = map[uint16]string{
doneFinal: "final",
doneMore: "more",
doneError: "error",
doneInxact: "inxact",
doneCount: "count",
doneAttn: "attn",
doneSrvError: "srverror",
}
// doneFlags2Str renders the status bits set in flags as a "|"-separated list
// of the names registered in doneFlags2str, for example "more|count".
//
// Bits are visited from least to most significant so the result is the same
// on every call; ranging over the map directly would produce the names in a
// random order. Bits without a registered name are skipped, and a zero flags
// value yields the empty string.
func doneFlags2Str(flags uint16) string {
	strs := make([]string, 0, len(doneFlags2str))
	// bit wraps to 0 after the 16th shift, which ends the loop.
	for bit := uint16(1); bit != 0; bit <<= 1 {
		if flags&bit == 0 {
			continue
		}
		if tag, ok := doneFlags2str[bit]; ok {
			strs = append(strs, tag)
		}
	}
	return strings.Join(strs, "|")
}
// processEnvChg parses the body of an ENVCHANGE token from sess.buf and
// applies the changes it understands to the session:
//
//   - database name        -> sess.database
//   - packet size          -> sess.buf.ResizeBuffer
//   - begin/commit/rollback transaction -> sess.tranid
//   - mirror partner       -> sess.partner
//   - routing              -> sess.routedServer, sess.routedPort
//
// All other known record types are read and discarded. The token starts with
// a 16-bit length; reading is confined to that many bytes and stops when they
// are exhausted. Malformed data is reported through badStreamPanic /
// badStreamPanicf.
//
// http://msdn.microsoft.com/en-us/library/dd303449.aspx
func processEnvChg(sess *tdsSession) {
	size := sess.buf.uint16()
	r := &io.LimitedReader{R: sess.buf, N: int64(size)}

	// skipBVarChar reads and discards one value with readBVarChar.
	skipBVarChar := func() {
		if _, err := readBVarChar(r); err != nil {
			badStreamPanic(err)
		}
	}
	// skipBVarByte reads and discards one value with readBVarByte.
	skipBVarByte := func() {
		if _, err := readBVarByte(r); err != nil {
			badStreamPanic(err)
		}
	}

	for {
		var envtype uint8
		err := binary.Read(r, binary.LittleEndian, &envtype)
		if err == io.EOF {
			// The declared token length has been consumed.
			return
		}
		if err != nil {
			badStreamPanic(err)
		}

		switch envtype {
		case envTypDatabase:
			sess.database, err = readBVarChar(r)
			if err != nil {
				badStreamPanic(err)
			}
			skipBVarChar() // old value

		case envTypLanguage, envTypCharset, envSortId, envSortFlags,
			envEnlistDTC, envDefectTran, envPromoteTran, envTranMgrAddr,
			envTranEnded, envResetConnAck, envStartedInstanceName:
			// Currently ignored: consume the two values of the record.
			// Some of these are specified as byte (not character) data,
			// so reading them with readBVarChar might be wrong.
			skipBVarChar()
			skipBVarChar()

		case envTypPacketSize:
			packetsize, err := readBVarChar(r)
			if err != nil {
				badStreamPanic(err)
			}
			skipBVarChar() // old value
			packetsizei, err := strconv.Atoi(packetsize)
			if err != nil {
				badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error())
			}
			sess.buf.ResizeBuffer(packetsizei)

		case envSqlCollation:
			// Currently ignored, but validated and consumed.
			var collationSize uint8
			if err = binary.Read(r, binary.LittleEndian, &collationSize); err != nil {
				badStreamPanic(err)
			}
			// SQL Collation data should contain 5 bytes in length
			if collationSize != 5 {
				badStreamPanicf("Invalid SQL Collation size value returned from server: %d", collationSize)
			}
			// 4 bytes, contains: LCID ColFlags Version
			var info uint32
			if err = binary.Read(r, binary.LittleEndian, &info); err != nil {
				badStreamPanic(err)
			}
			// 1 byte, contains: sortID
			var sortID uint8
			if err = binary.Read(r, binary.LittleEndian, &sortID); err != nil {
				badStreamPanic(err)
			}
			skipBVarChar() // old value, should be 0

		case envTypBeginTran:
			tranid, err := readBVarByte(r)
			// Report a read failure first; otherwise it would be hidden
			// behind a misleading "invalid size" message.
			if err != nil {
				badStreamPanic(err)
			}
			if len(tranid) != 8 {
				badStreamPanicf("invalid size of transaction identifier: %d", len(tranid))
			}
			sess.tranid = binary.LittleEndian.Uint64(tranid)
			if sess.logFlags&logTransaction != 0 {
				sess.log.Printf("BEGIN TRANSACTION %x\n", sess.tranid)
			}
			skipBVarByte() // old value

		case envTypCommitTran, envTypRollbackTran:
			skipBVarByte()
			skipBVarByte()
			if sess.logFlags&logTransaction != 0 {
				if envtype == envTypCommitTran {
					sess.log.Printf("COMMIT TRANSACTION %x\n", sess.tranid)
				} else {
					sess.log.Printf("ROLLBACK TRANSACTION %x\n", sess.tranid)
				}
			}
			sess.tranid = 0

		case envDatabaseMirrorPartner:
			sess.partner, err = readBVarChar(r)
			if err != nil {
				badStreamPanic(err)
			}
			skipBVarChar() // old value

		case envRouting:
			// RoutingData message is:
			// ValueLength                 USHORT
			// Protocol (TCP = 0)          BYTE
			// ProtocolProperty (new port) USHORT
			// AlternateServer             US_VARCHAR
			if _, err := readUshort(r); err != nil {
				badStreamPanic(err)
			}
			protocol, err := readByte(r)
			if err != nil {
				badStreamPanic(err)
			}
			if protocol != 0 {
				// Previously this panicked with a nil error.
				badStreamPanicf("unsupported routing protocol in ENVCHANGE record: %d", protocol)
			}
			newPort, err := readUshort(r)
			if err != nil {
				badStreamPanic(err)
			}
			newServer, err := readUsVarChar(r)
			if err != nil {
				badStreamPanic(err)
			}
			// consume the OLDVALUE = %x00 %x00
			if _, err := readUshort(r); err != nil {
				badStreamPanic(err)
			}
			sess.routedServer = newServer
			sess.routedPort = newPort

		default:
			// We don't know how this record is laid out, so we cannot find
			// the start of the next one. Discard the rest of the token (its
			// total length is known) and stop, instead of interpreting the
			// remaining bytes as further records.
			sess.log.Printf("WARN: Unknown ENVCHANGE record detected with type id = %d\n", envtype)
			var scratch [256]byte
			for {
				if _, err := r.Read(scratch[:]); err != nil {
					if err != io.EOF {
						badStreamPanic(err)
					}
					return
				}
			}
		}
	}
}
// parseReturnStatus reads the body of a RETURNSTATUS token, a single 32-bit
// integer, and returns it as a ReturnStatus.
// http://msdn.microsoft.com/en-us/library/dd358180.aspx
func parseReturnStatus(r *tdsBuffer) ReturnStatus {
	status := r.int32()
	return ReturnStatus(status)
}
// parseOrder reads the body of an ORDER token: a 16-bit byte length followed
// by length/2 16-bit column ids.
func parseOrder(r *tdsBuffer) (res orderStruct) {
	byteLen := int(r.uint16())
	res.ColIds = make([]uint16, byteLen/2)
	for i := range res.ColIds {
		res.ColIds[i] = r.uint16()
	}
	return res
}
// parseDone reads the fixed-size body of a DONE or DONEPROC token: 16-bit
// status, 16-bit current command and 64-bit row count, in that order. The
// errors field of the result is left empty for the caller to fill in.
// https://msdn.microsoft.com/en-us/library/dd340421.aspx
func parseDone(r *tdsBuffer) (res doneStruct) {
	status := r.uint16()
	curCmd := r.uint16()
	rowCount := r.uint64()
	return doneStruct{Status: status, CurCmd: curCmd, RowCount: rowCount}
}
// parseDoneInProc reads the body of a DONEINPROC token. The wire layout is
// the same as for DONE (status, current command, row count), so the bytes are
// read by parseDone and the result converted to the DONEINPROC type.
// https://msdn.microsoft.com/en-us/library/dd340553.aspx
func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) {
	return doneInProcStruct(parseDone(r))
}
// sspiMsg is the raw payload of an SSPI token.
type sspiMsg []byte

// parseSSPIMsg reads an SSPI token body: a 16-bit length followed by that
// many payload bytes, which are returned unmodified.
func parseSSPIMsg(r *tdsBuffer) sspiMsg {
	payload := make(sspiMsg, r.uint16())
	r.ReadFull(payload)
	return payload
}
// loginAckStruct is the parsed form of a LOGINACK token.
type loginAckStruct struct {
// Interface is the first byte of the token body.
Interface uint8
// TDSVersion is the next four bytes, decoded big-endian.
TDSVersion uint32
// ProgName is the length-prefixed UCS-2 program name, converted to a Go string.
ProgName string
// ProgVer is the last four bytes of the token body, decoded big-endian.
ProgVer uint32
}
// parseLoginAck reads and decodes a LOGINACK token body. After the 16-bit
// token length the layout is:
//
//	Interface    1 byte
//	TDSVersion   4 bytes, big-endian
//	ProgName     1-byte character count followed by count*2 bytes of UCS-2
//	ProgVer      4 bytes, big-endian (taken from the end of the body)
//
// A body that is too short for this layout, or a program name that fails to
// decode, is reported through badStreamPanic.
func parseLoginAck(r *tdsBuffer) loginAckStruct {
	size := r.uint16()
	buf := make([]byte, size)
	r.ReadFull(buf)

	const (
		nameLenOff = 1 + 4          // offset of the program-name length byte
		nameOff    = nameLenOff + 1 // offset of the first program-name byte
		verLen     = 4              // size of the trailing program version
	)
	if len(buf) < nameOff+verLen {
		badStreamPanic(fmt.Errorf("invalid LOGINACK token: length %d is too short", len(buf)))
	}

	var res loginAckStruct
	res.Interface = buf[0]
	res.TDSVersion = binary.BigEndian.Uint32(buf[1:])

	// Compute the end offset in int: the length is a byte, and doing the
	// arithmetic in uint8 wraps around for names of 125+ characters.
	nameEnd := nameOff + int(buf[nameLenOff])*2
	if nameEnd > len(buf)-verLen {
		badStreamPanic(fmt.Errorf("invalid LOGINACK token: program name of %d characters does not fit in %d bytes",
			buf[nameLenOff], len(buf)))
	}
	var err error
	if res.ProgName, err = ucs22str(buf[nameOff:nameEnd]); err != nil {
		badStreamPanic(err)
	}
	res.ProgVer = binary.BigEndian.Uint32(buf[len(buf)-verLen:])
	return res
}
// parseColMetadata72 reads a COLMETADATA token body. It starts with a 16-bit
// column count; the value 0xffff means no metadata was sent and yields nil.
// Each column is then described by a 32-bit user type, 16-bit flags, a
// TYPE_INFO structure and a length-prefixed column name.
// http://msdn.microsoft.com/en-us/library/dd357363.aspx
func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) {
	const noMetadata = 0xffff

	count := r.uint16()
	if count == noMetadata {
		return nil
	}

	columns = make([]columnStruct, count)
	for i := 0; i < len(columns); i++ {
		col := &columns[i]
		col.UserType = r.uint32()
		col.Flags = r.uint16()
		col.ti = readTypeInfo(r) // TYPE_INFO structure
		col.ColName = r.BVarChar()
	}
	return columns
}
// parseRow reads one ROW token body: a value for every column, in column
// order, decoded by the reader attached to that column's type info. Values
// are stored into row, which must have room for len(columns) entries.
// http://msdn.microsoft.com/en-us/library/dd357254.aspx
func parseRow(r *tdsBuffer, columns []columnStruct, row []interface{}) {
	for i := range columns {
		// The reader is handed a pointer into a per-iteration copy, so it
		// never sees or alters the caller's column metadata.
		col := columns[i]
		row[i] = col.ti.Reader(&col.ti, r)
	}
}
// parseNbcRow reads one NBCROW token body. It begins with a bitmap holding
// one bit per column (least significant bit first within each byte); a set
// bit means the column is NULL and has no data in the stream. The remaining
// columns follow in order and are decoded by their type's reader.
// http://msdn.microsoft.com/en-us/library/dd304783.aspx
func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) {
	nullBitmap := make([]byte, (len(columns)+7)/8)
	r.ReadFull(nullBitmap)

	for i := range columns {
		mask := byte(1) << (uint(i) % 8)
		if nullBitmap[i/8]&mask == 0 {
			// Work on a copy so the reader cannot touch the caller's metadata.
			col := columns[i]
			row[i] = col.ti.Reader(&col.ti, r)
		} else {
			row[i] = nil
		}
	}
}
// parseError72 reads an ERROR token body into an Error. Fields are read in
// wire order: number, state, class, message, server name, procedure name and
// line number.
// http://msdn.microsoft.com/en-us/library/dd304156.aspx
func parseError72(r *tdsBuffer) (res Error) {
	r.uint16() // token length; read to advance the stream, value not used

	res.Number = r.int32()
	res.State = r.byte()
	res.Class = r.byte()
	res.Message = r.UsVarChar()
	res.ServerName = r.BVarChar()
	res.ProcName = r.BVarChar()
	res.LineNo = r.int32()
	return res
}
// parseInfo reads an INFO token body into an Error. The body is read field
// for field exactly like an ERROR token, so the work is delegated to
// parseError72.
// http://msdn.microsoft.com/en-us/library/dd304156.aspx
func parseInfo(r *tdsBuffer) (res Error) {
	return parseError72(r)
}
// parseReturnValue reads a RETURNVALUE token body and returns the parameter
// name together with its decoded value. The remaining header fields are read
// only to advance the stream. No CryptoMetadata is read.
// https://msdn.microsoft.com/en-us/library/dd303881.aspx
func parseReturnValue(r *tdsBuffer) (nv namedValue) {
	r.uint16()             // ParamOrdinal
	nv.Name = r.BVarChar() // ParamName
	r.byte()               // Status
	r.uint32()             // UserType (uint16 prior to 7.2)
	r.uint16()             // Flags

	typeInfo := readTypeInfo(r) // TypeInfo
	nv.Value = typeInfo.Reader(&typeInfo, r)
	return nv
}
// processSingleResponse reads one reply message from sess.buf, parses its
// tokens and sends the results on ch. It always closes ch before returning.
//
// What is sent on ch:
//   - parsed tokens (return status, login ack, order, done, column metadata,
//     rows, SSPI message);
//   - an error if BeginRead fails, if the server reports an internal error in
//     a DONE token, or if an output parameter cannot be scanned;
//   - the recovered value if parsing panics (the parsing helpers signal a
//     malformed stream by panicking).
//
// ERROR tokens are not sent individually; they are collected and attached to
// the next DONE/DONEPROC token. INFO tokens are only logged. ENVCHANGE tokens
// are applied to the session. RETURNVALUE tokens are stored into the matching
// entry of outs, keyed by the parameter name without its first character.
//
// Parsing stops after an SSPI token, after a DONE/DONEPROC token without the
// doneMore flag, or on the first failure.
func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
	defer func() {
		if err := recover(); err != nil {
			if sess.logFlags&logErrors != 0 {
				sess.log.Printf("ERROR: Intercepted panic %v", err)
			}
			ch <- err
		}
		close(ch)
	}()

	packetType, err := sess.buf.BeginRead()
	if err != nil {
		if sess.logFlags&logErrors != 0 {
			sess.log.Printf("ERROR: BeginRead failed %v", err)
		}
		ch <- err
		return
	}
	if packetType != packReply {
		badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packetType, packReply))
	}

	var columns []columnStruct
	errs := make([]Error, 0, 5)
	for {
		tok := token(sess.buf.byte())
		if sess.logFlags&logDebug != 0 {
			sess.log.Printf("got token %v", tok)
		}
		switch tok {
		case tokenSSPI:
			ch <- parseSSPIMsg(sess.buf)
			return
		case tokenReturnStatus:
			returnStatus := parseReturnStatus(sess.buf)
			ch <- returnStatus
		case tokenLoginAck:
			loginAck := parseLoginAck(sess.buf)
			ch <- loginAck
		case tokenOrder:
			order := parseOrder(sess.buf)
			ch <- order
		case tokenDoneInProc:
			done := parseDoneInProc(sess.buf)
			if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
				sess.log.Printf("(%d row(s) affected)\n", done.RowCount)
			}
			ch <- done
		case tokenDone, tokenDoneProc:
			done := parseDone(sess.buf)
			// Hand over every server error seen so far with this token.
			done.errors = errs
			if sess.logFlags&logDebug != 0 {
				sess.log.Printf("got DONE or DONEPROC status=%d", done.Status)
			}
			if done.Status&doneSrvError != 0 {
				ch <- errors.New("SQL Server had internal error")
				return
			}
			if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
				sess.log.Printf("(%d row(s) affected)\n", done.RowCount)
			}
			ch <- done
			if done.Status&doneMore == 0 {
				return
			}
		case tokenColMetadata:
			columns = parseColMetadata72(sess.buf)
			ch <- columns
		case tokenRow:
			row := make([]interface{}, len(columns))
			parseRow(sess.buf, columns, row)
			ch <- row
		case tokenNbcRow:
			row := make([]interface{}, len(columns))
			parseNbcRow(sess.buf, columns, row)
			ch <- row
		case tokenEnvChange:
			processEnvChg(sess)
		case tokenError:
			srvErr := parseError72(sess.buf)
			if sess.logFlags&logDebug != 0 {
				sess.log.Printf("got ERROR %d %s", srvErr.Number, srvErr.Message)
			}
			errs = append(errs, srvErr)
			if sess.logFlags&logErrors != 0 {
				sess.log.Println(srvErr.Message)
			}
		case tokenInfo:
			info := parseInfo(sess.buf)
			if sess.logFlags&logDebug != 0 {
				sess.log.Printf("got INFO %d %s", info.Number, info.Message)
			}
			if sess.logFlags&logMessages != 0 {
				sess.log.Println(info.Message)
			}
		case tokenReturnValue:
			nv := parseReturnValue(sess.buf)
			if len(nv.Name) > 0 {
				name := nv.Name[1:] // Remove the leading "@".
				if ov, has := outs[name]; has {
					if err := scanIntoOut(name, nv.Value, ov); err != nil {
						// Report through the session logger, honouring the
						// log flags, rather than printing to stdout.
						if sess.logFlags&logErrors != 0 {
							sess.log.Printf("ERROR: scan into output parameter %q failed: %v", name, err)
						}
						ch <- err
					}
				}
			}
		default:
			badStreamPanic(fmt.Errorf("unknown token type returned: %v", tok))
		}
	}
}
// parseRespIter is the result of one parseResp.iter step; it tells the
// processResponse loop what to do next.
type parseRespIter byte
const (
parseRespIterContinue parseRespIter = iota // Continue parsing current token.
parseRespIterNext // Fetch the next token.
parseRespIterDone // Done with parsing the response.
)
// parseRespState is the state of the parseResp state machine. The zero value
// is parseRespStateNormal.
type parseRespState byte
const (
parseRespStateNormal parseRespState = iota // Normal response state.
parseRespStateCancel // Query is canceled, wait for server to confirm.
parseRespStateClosing // Waiting for tokens to come through.
)
// parseResp holds the state used by processResponse while it forwards tokens
// and handles cancellation.
type parseResp struct {
// sess is the session whose buffer is read and used to send attention signals.
sess *tdsSession
// ctxDone is the context's Done channel; it is set to nil once it has
// fired so the cancellation is acted on only once.
ctxDone <-chan struct{}
// state selects the branch taken by iter.
state parseRespState
// cancelError is the timeout error that triggered a cancel; it is held
// back until it can be delivered to the consumer.
cancelError error
}
// sendAttention sends an attention signal on the session buffer. On success
// the state machine moves to parseRespStateCancel and parsing continues; on
// failure the error is forwarded on ch and parsing is done.
func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter {
	err := sendAttention(ts.sess.buf)
	if err == nil {
		ts.state = parseRespStateCancel
		return parseRespIterContinue
	}
	ts.dlogf("failed to send attention signal %v", err)
	ch <- err
	return parseRespIterDone
}
// dlog writes msg to the session log, but only when debug logging is enabled.
func (ts *parseResp) dlog(msg string) {
	if ts.sess.logFlags&logDebug == 0 {
		return
	}
	ts.sess.log.Println(msg)
}
// dlogf writes a formatted message to the session log, but only when debug
// logging is enabled.
func (ts *parseResp) dlogf(f string, v ...interface{}) {
	if ts.sess.logFlags&logDebug == 0 {
		return
	}
	ts.sess.log.Printf(f, v...)
}
// iter performs one step of the response state machine. It receives at most
// one token from tokChan (the output of a processSingleResponse goroutine),
// decides what to forward on ch (the channel read by the consumer) and
// returns what the caller should do next:
//
//   - parseRespIterContinue: call iter again with the same tokChan;
//   - parseRespIterNext: tokChan is exhausted, start reading another response;
//   - parseRespIterDone: stop processing.
//
// ctx is only used to obtain the error that is reported once the server has
// confirmed a cancellation. It panics if ts.state holds an unknown value.
func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter {
switch ts.state {
default:
panic("unknown state")
case parseRespStateNormal:
// Forward tokens until the response ends, a timeout error arrives or
// the context is cancelled.
select {
case tok, ok := <-tokChan:
if !ok {
ts.dlog("response finished")
return parseRespIterDone
}
// A network timeout is not forwarded right away: it is remembered and
// the running request is cancelled on the server first.
if err, ok := tok.(net.Error); ok && err.Timeout() {
ts.cancelError = err
ts.dlog("got timeout error, sending attention signal to server")
return ts.sendAttention(ch)
}
// Pass the token along.
ch <- tok
return parseRespIterContinue
case <-ts.ctxDone:
// A nil channel is never ready in a select, so clearing ctxDone makes
// sure this case cannot fire again.
ts.ctxDone = nil
ts.dlog("got cancel message, sending attention signal to server")
return ts.sendAttention(ch)
}
case parseRespStateCancel: // Read all responses until a DONE or error is received.
select {
case tok, ok := <-tokChan:
if !ok {
// The current response ended without the attention acknowledgement;
// ask the caller to read the next response.
ts.dlog("response finished but waiting for attention ack")
return parseRespIterNext
}
switch tok := tok.(type) {
default:
// Ignore all other tokens while waiting.
// The TDS spec says other tokens may arrive after an attention
// signal is sent. Ignore these tokens and continue looking for
// a DONE with attention confirm mark.
case doneStruct:
if tok.Status&doneAttn != 0 {
ts.dlog("got cancellation confirmation from server")
// Report why the request was cancelled: the stored timeout error if
// there is one, otherwise the context's error.
if ts.cancelError != nil {
ch <- ts.cancelError
ts.cancelError = nil
} else {
ch <- ctx.Err()
}
return parseRespIterDone
}
// If an error happens during cancel, pass it along and just stop.
// We are uncertain to receive more tokens.
case error:
ch <- tok
ts.state = parseRespStateClosing
}
return parseRespIterContinue
case <-ts.ctxDone:
// The context ended while a cancel was already in progress: stop
// waiting for the acknowledgement and just drain the token channel.
ts.ctxDone = nil
ts.state = parseRespStateClosing
return parseRespIterContinue
}
case parseRespStateClosing: // Wait for current token chan to close.
// Tokens received in this state are discarded.
if _, ok := <-tokChan; !ok {
ts.dlog("response finished")
return parseRespIterDone
}
return parseRespIterContinue
}
}
// processResponse reads server responses for sess and forwards the resulting
// tokens on ch until the parseResp state machine reports completion. Each
// response is parsed by a processSingleResponse goroutine feeding a fresh
// token channel. ch is closed on return; a timeout error still held by the
// state machine at that point is sent first so the failure is not lost.
func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
	ts := &parseResp{sess: sess, ctxDone: ctx.Done()}

	defer func() {
		// Ensure any remaining error is piped through
		// or the query may look like it executed when it actually failed.
		if pending := ts.cancelError; pending != nil {
			ch <- pending
			ts.cancelError = nil
		}
		close(ch)
	}()

	// One pass of the outer loop per response.
	for {
		ts.dlog("initiating response reading")
		tokChan := make(chan tokenStruct)
		go processSingleResponse(sess, tokChan, outs)

		// Step through the tokens of this response until the state machine
		// asks for the next response or reports that everything is done.
		for nextResponse := false; !nextResponse; {
			switch ts.iter(ctx, ch, tokChan) {
			case parseRespIterNext:
				nextResponse = true
			case parseRespIterDone:
				return
			}
		}
	}
}