mirror of
				https://github.com/go-gitea/gitea
				synced 2025-11-03 21:08:25 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			283 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			283 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package pq
 | 
						|
 | 
						|
import (
 | 
						|
	"database/sql/driver"
 | 
						|
	"encoding/binary"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"sync"
 | 
						|
)
 | 
						|
 | 
						|
// Errors returned by the COPY support in this file.
var (
	// errCopyNotSupportedOutsideTxn is returned by prepareCopyIn when the
	// connection is not inside a transaction.
	errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")

	// errBinaryCopyNotSupported is returned when the server asks for a
	// binary-format COPY; only the text format is implemented.
	errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")

	// errCopyToNotSupported is returned when the statement turns out to be a
	// COPY TO (the server answered with a CopyOutResponse).
	errCopyToNotSupported = errors.New("pq: COPY TO is not supported")

	// errCopyInClosed is returned by Exec once the COPY statement has been
	// closed.
	errCopyInClosed = errors.New("pq: copyin statement has already been closed")

	// errCopyInProgress is not referenced in this file; presumably it is
	// returned elsewhere in the package while a COPY is active.
	errCopyInProgress = errors.New("pq: COPY in progress")
)
 | 
						|
 | 
						|
// CopyIn creates a COPY FROM statement which can be prepared with
 | 
						|
// Tx.Prepare().  The target table should be visible in search_path.
 | 
						|
func CopyIn(table string, columns ...string) string {
 | 
						|
	stmt := "COPY " + QuoteIdentifier(table) + " ("
 | 
						|
	for i, col := range columns {
 | 
						|
		if i != 0 {
 | 
						|
			stmt += ", "
 | 
						|
		}
 | 
						|
		stmt += QuoteIdentifier(col)
 | 
						|
	}
 | 
						|
	stmt += ") FROM STDIN"
 | 
						|
	return stmt
 | 
						|
}
 | 
						|
 | 
						|
// CopyInSchema creates a COPY FROM statement which can be prepared with
 | 
						|
// Tx.Prepare().
 | 
						|
func CopyInSchema(schema, table string, columns ...string) string {
 | 
						|
	stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " ("
 | 
						|
	for i, col := range columns {
 | 
						|
		if i != 0 {
 | 
						|
			stmt += ", "
 | 
						|
		}
 | 
						|
		stmt += QuoteIdentifier(col)
 | 
						|
	}
 | 
						|
	stmt += ") FROM STDIN"
 | 
						|
	return stmt
 | 
						|
}
 | 
						|
 | 
						|
type copyin struct {
 | 
						|
	cn      *conn
 | 
						|
	buffer  []byte
 | 
						|
	rowData chan []byte
 | 
						|
	done    chan bool
 | 
						|
 | 
						|
	closed bool
 | 
						|
 | 
						|
	sync.Mutex // guards err
 | 
						|
	err        error
 | 
						|
}
 | 
						|
 | 
						|
// Sizing of the CopyData buffer used by copyin.
const (
	// ciBufferSize is the initial capacity of copyin.buffer.
	ciBufferSize = 64 * 1024

	// ciBufferFlushSize is the length above which Exec flushes the buffer. It
	// sits 1 KiB below ciBufferSize so the buffer is sent before it fills up
	// and needs reallocation.
	ciBufferFlushSize = ciBufferSize - 1024
)
 | 
						|
 | 
						|
func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
 | 
						|
	if !cn.isInTransaction() {
 | 
						|
		return nil, errCopyNotSupportedOutsideTxn
 | 
						|
	}
 | 
						|
 | 
						|
	ci := ©in{
 | 
						|
		cn:      cn,
 | 
						|
		buffer:  make([]byte, 0, ciBufferSize),
 | 
						|
		rowData: make(chan []byte),
 | 
						|
		done:    make(chan bool, 1),
 | 
						|
	}
 | 
						|
	// add CopyData identifier + 4 bytes for message length
 | 
						|
	ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
 | 
						|
 | 
						|
	b := cn.writeBuf('Q')
 | 
						|
	b.string(q)
 | 
						|
	cn.send(b)
 | 
						|
 | 
						|
awaitCopyInResponse:
 | 
						|
	for {
 | 
						|
		t, r := cn.recv1()
 | 
						|
		switch t {
 | 
						|
		case 'G':
 | 
						|
			if r.byte() != 0 {
 | 
						|
				err = errBinaryCopyNotSupported
 | 
						|
				break awaitCopyInResponse
 | 
						|
			}
 | 
						|
			go ci.resploop()
 | 
						|
			return ci, nil
 | 
						|
		case 'H':
 | 
						|
			err = errCopyToNotSupported
 | 
						|
			break awaitCopyInResponse
 | 
						|
		case 'E':
 | 
						|
			err = parseError(r)
 | 
						|
		case 'Z':
 | 
						|
			if err == nil {
 | 
						|
				ci.setBad()
 | 
						|
				errorf("unexpected ReadyForQuery in response to COPY")
 | 
						|
			}
 | 
						|
			cn.processReadyForQuery(r)
 | 
						|
			return nil, err
 | 
						|
		default:
 | 
						|
			ci.setBad()
 | 
						|
			errorf("unknown response for copy query: %q", t)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// something went wrong, abort COPY before we return
 | 
						|
	b = cn.writeBuf('f')
 | 
						|
	b.string(err.Error())
 | 
						|
	cn.send(b)
 | 
						|
 | 
						|
	for {
 | 
						|
		t, r := cn.recv1()
 | 
						|
		switch t {
 | 
						|
		case 'c', 'C', 'E':
 | 
						|
		case 'Z':
 | 
						|
			// correctly aborted, we're done
 | 
						|
			cn.processReadyForQuery(r)
 | 
						|
			return nil, err
 | 
						|
		default:
 | 
						|
			ci.setBad()
 | 
						|
			errorf("unknown response for CopyFail: %q", t)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (ci *copyin) flush(buf []byte) {
 | 
						|
	// set message length (without message identifier)
 | 
						|
	binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
 | 
						|
 | 
						|
	_, err := ci.cn.c.Write(buf)
 | 
						|
	if err != nil {
 | 
						|
		panic(err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (ci *copyin) resploop() {
 | 
						|
	for {
 | 
						|
		var r readBuf
 | 
						|
		t, err := ci.cn.recvMessage(&r)
 | 
						|
		if err != nil {
 | 
						|
			ci.setBad()
 | 
						|
			ci.setError(err)
 | 
						|
			ci.done <- true
 | 
						|
			return
 | 
						|
		}
 | 
						|
		switch t {
 | 
						|
		case 'C':
 | 
						|
			// complete
 | 
						|
		case 'N':
 | 
						|
			// NoticeResponse
 | 
						|
		case 'Z':
 | 
						|
			ci.cn.processReadyForQuery(&r)
 | 
						|
			ci.done <- true
 | 
						|
			return
 | 
						|
		case 'E':
 | 
						|
			err := parseError(&r)
 | 
						|
			ci.setError(err)
 | 
						|
		default:
 | 
						|
			ci.setBad()
 | 
						|
			ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
 | 
						|
			ci.done <- true
 | 
						|
			return
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (ci *copyin) setBad() {
 | 
						|
	ci.Lock()
 | 
						|
	ci.cn.bad = true
 | 
						|
	ci.Unlock()
 | 
						|
}
 | 
						|
 | 
						|
func (ci *copyin) isBad() bool {
 | 
						|
	ci.Lock()
 | 
						|
	b := ci.cn.bad
 | 
						|
	ci.Unlock()
 | 
						|
	return b
 | 
						|
}
 | 
						|
 | 
						|
func (ci *copyin) isErrorSet() bool {
 | 
						|
	ci.Lock()
 | 
						|
	isSet := (ci.err != nil)
 | 
						|
	ci.Unlock()
 | 
						|
	return isSet
 | 
						|
}
 | 
						|
 | 
						|
// setError() sets ci.err if one has not been set already.  Caller must not be
 | 
						|
// holding ci.Mutex.
 | 
						|
func (ci *copyin) setError(err error) {
 | 
						|
	ci.Lock()
 | 
						|
	if ci.err == nil {
 | 
						|
		ci.err = err
 | 
						|
	}
 | 
						|
	ci.Unlock()
 | 
						|
}
 | 
						|
 | 
						|
func (ci *copyin) NumInput() int {
 | 
						|
	return -1
 | 
						|
}
 | 
						|
 | 
						|
func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
 | 
						|
	return nil, ErrNotSupported
 | 
						|
}
 | 
						|
 | 
						|
// Exec inserts values into the COPY stream. The insert is asynchronous
 | 
						|
// and Exec can return errors from previous Exec calls to the same
 | 
						|
// COPY stmt.
 | 
						|
//
 | 
						|
// You need to call Exec(nil) to sync the COPY stream and to get any
 | 
						|
// errors from pending data, since Stmt.Close() doesn't return errors
 | 
						|
// to the user.
 | 
						|
func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
 | 
						|
	if ci.closed {
 | 
						|
		return nil, errCopyInClosed
 | 
						|
	}
 | 
						|
 | 
						|
	if ci.isBad() {
 | 
						|
		return nil, driver.ErrBadConn
 | 
						|
	}
 | 
						|
	defer ci.cn.errRecover(&err)
 | 
						|
 | 
						|
	if ci.isErrorSet() {
 | 
						|
		return nil, ci.err
 | 
						|
	}
 | 
						|
 | 
						|
	if len(v) == 0 {
 | 
						|
		return nil, ci.Close()
 | 
						|
	}
 | 
						|
 | 
						|
	numValues := len(v)
 | 
						|
	for i, value := range v {
 | 
						|
		ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
 | 
						|
		if i < numValues-1 {
 | 
						|
			ci.buffer = append(ci.buffer, '\t')
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	ci.buffer = append(ci.buffer, '\n')
 | 
						|
 | 
						|
	if len(ci.buffer) > ciBufferFlushSize {
 | 
						|
		ci.flush(ci.buffer)
 | 
						|
		// reset buffer, keep bytes for message identifier and length
 | 
						|
		ci.buffer = ci.buffer[:5]
 | 
						|
	}
 | 
						|
 | 
						|
	return driver.RowsAffected(0), nil
 | 
						|
}
 | 
						|
 | 
						|
func (ci *copyin) Close() (err error) {
 | 
						|
	if ci.closed { // Don't do anything, we're already closed
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	ci.closed = true
 | 
						|
 | 
						|
	if ci.isBad() {
 | 
						|
		return driver.ErrBadConn
 | 
						|
	}
 | 
						|
	defer ci.cn.errRecover(&err)
 | 
						|
 | 
						|
	if len(ci.buffer) > 0 {
 | 
						|
		ci.flush(ci.buffer)
 | 
						|
	}
 | 
						|
	// Avoid touching the scratch buffer as resploop could be using it.
 | 
						|
	err = ci.cn.sendSimpleMessage('c')
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	<-ci.done
 | 
						|
	ci.cn.inCopy = false
 | 
						|
 | 
						|
	if ci.isErrorSet() {
 | 
						|
		err = ci.err
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 |