1
1
mirror of https://github.com/go-gitea/gitea synced 2025-07-22 18:28:37 +00:00

Use vendored go-swagger (#8087)

* Use vendored go-swagger

* vendor go-swagger

* revert un wanteed change

* remove un-needed GO111MODULE

* Update Makefile

Co-Authored-By: techknowlogick <matti@mdranta.net>
This commit is contained in:
Antoine GIRARD
2019-09-04 21:53:54 +02:00
committed by Lauris BH
parent 4cb1bdddc8
commit 9fe4437bda
686 changed files with 143379 additions and 17 deletions

View File

@@ -0,0 +1,389 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"fmt"
"io"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// Copier is a type that allows copying between ValueReaders, ValueWriters, and
// []byte values.
type Copier struct{}
// NewCopier creates a new copier with the given registry. If a nil registry is provided
// a default registry is used.
func NewCopier() Copier {
return Copier{}
}
// CopyDocument handles copying a document from src to dst.
func CopyDocument(dst ValueWriter, src ValueReader) error {
return Copier{}.CopyDocument(dst, src)
}
// CopyDocument handles copying one document from the src to the dst.
func (c Copier) CopyDocument(dst ValueWriter, src ValueReader) error {
dr, err := src.ReadDocument()
if err != nil {
return err
}
dw, err := dst.WriteDocument()
if err != nil {
return err
}
return c.copyDocumentCore(dw, dr)
}
// CopyDocumentFromBytes copies the values from a BSON document represented as a
// []byte to a ValueWriter.
func (c Copier) CopyDocumentFromBytes(dst ValueWriter, src []byte) error {
dw, err := dst.WriteDocument()
if err != nil {
return err
}
err = c.CopyBytesToDocumentWriter(dw, src)
if err != nil {
return err
}
return dw.WriteDocumentEnd()
}
// CopyBytesToDocumentWriter copies the values from a BSON document represented as a []byte to a
// DocumentWriter.
func (c Copier) CopyBytesToDocumentWriter(dst DocumentWriter, src []byte) error {
// TODO(skriptble): Create errors types here. Anything thats a tag should be a property.
length, rem, ok := bsoncore.ReadLength(src)
if !ok {
return fmt.Errorf("couldn't read length from src, not enough bytes. length=%d", len(src))
}
if len(src) < int(length) {
return fmt.Errorf("length read exceeds number of bytes available. length=%d bytes=%d", len(src), length)
}
rem = rem[:length-4]
var t bsontype.Type
var key string
var val bsoncore.Value
for {
t, rem, ok = bsoncore.ReadType(rem)
if !ok {
return io.EOF
}
if t == bsontype.Type(0) {
if len(rem) != 0 {
return fmt.Errorf("document end byte found before end of document. remaining bytes=%v", rem)
}
break
}
key, rem, ok = bsoncore.ReadKey(rem)
if !ok {
return fmt.Errorf("invalid key found. remaining bytes=%v", rem)
}
dvw, err := dst.WriteDocumentElement(key)
if err != nil {
return err
}
val, rem, ok = bsoncore.ReadValue(rem, t)
if !ok {
return fmt.Errorf("not enough bytes available to read type. bytes=%d type=%s", len(rem), t)
}
err = c.CopyValueFromBytes(dvw, t, val.Data)
if err != nil {
return err
}
}
return nil
}
// CopyDocumentToBytes copies an entire document from the ValueReader and
// returns it as bytes.
func (c Copier) CopyDocumentToBytes(src ValueReader) ([]byte, error) {
return c.AppendDocumentBytes(nil, src)
}
// AppendDocumentBytes functions the same as CopyDocumentToBytes, but will
// append the result to dst.
func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) {
if br, ok := src.(BytesReader); ok {
_, dst, err := br.ReadValueBytes(dst)
return dst, err
}
vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
vw.reset(dst)
err := c.CopyDocument(vw, src)
dst = vw.buf
return dst, err
}
// CopyValueFromBytes will write the value represtend by t and src to dst.
func (c Copier) CopyValueFromBytes(dst ValueWriter, t bsontype.Type, src []byte) error {
if wvb, ok := dst.(BytesWriter); ok {
return wvb.WriteValueBytes(t, src)
}
vr := vrPool.Get().(*valueReader)
defer vrPool.Put(vr)
vr.reset(src)
vr.pushElement(t)
return c.CopyValue(dst, vr)
}
// CopyValueToBytes copies a value from src and returns it as a bsontype.Type and a
// []byte.
func (c Copier) CopyValueToBytes(src ValueReader) (bsontype.Type, []byte, error) {
return c.AppendValueBytes(nil, src)
}
// AppendValueBytes functions the same as CopyValueToBytes, but will append the
// result to dst.
func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bsontype.Type, []byte, error) {
if br, ok := src.(BytesReader); ok {
return br.ReadValueBytes(dst)
}
vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
start := len(dst)
vw.reset(dst)
vw.push(mElement)
err := c.CopyValue(vw, src)
if err != nil {
return 0, dst, err
}
return bsontype.Type(vw.buf[start]), vw.buf[start+2:], nil
}
// CopyValue will copy a single value from src to dst.
func (c Copier) CopyValue(dst ValueWriter, src ValueReader) error {
var err error
switch src.Type() {
case bsontype.Double:
var f64 float64
f64, err = src.ReadDouble()
if err != nil {
break
}
err = dst.WriteDouble(f64)
case bsontype.String:
var str string
str, err = src.ReadString()
if err != nil {
return err
}
err = dst.WriteString(str)
case bsontype.EmbeddedDocument:
err = c.CopyDocument(dst, src)
case bsontype.Array:
err = c.copyArray(dst, src)
case bsontype.Binary:
var data []byte
var subtype byte
data, subtype, err = src.ReadBinary()
if err != nil {
break
}
err = dst.WriteBinaryWithSubtype(data, subtype)
case bsontype.Undefined:
err = src.ReadUndefined()
if err != nil {
break
}
err = dst.WriteUndefined()
case bsontype.ObjectID:
var oid primitive.ObjectID
oid, err = src.ReadObjectID()
if err != nil {
break
}
err = dst.WriteObjectID(oid)
case bsontype.Boolean:
var b bool
b, err = src.ReadBoolean()
if err != nil {
break
}
err = dst.WriteBoolean(b)
case bsontype.DateTime:
var dt int64
dt, err = src.ReadDateTime()
if err != nil {
break
}
err = dst.WriteDateTime(dt)
case bsontype.Null:
err = src.ReadNull()
if err != nil {
break
}
err = dst.WriteNull()
case bsontype.Regex:
var pattern, options string
pattern, options, err = src.ReadRegex()
if err != nil {
break
}
err = dst.WriteRegex(pattern, options)
case bsontype.DBPointer:
var ns string
var pointer primitive.ObjectID
ns, pointer, err = src.ReadDBPointer()
if err != nil {
break
}
err = dst.WriteDBPointer(ns, pointer)
case bsontype.JavaScript:
var js string
js, err = src.ReadJavascript()
if err != nil {
break
}
err = dst.WriteJavascript(js)
case bsontype.Symbol:
var symbol string
symbol, err = src.ReadSymbol()
if err != nil {
break
}
err = dst.WriteSymbol(symbol)
case bsontype.CodeWithScope:
var code string
var srcScope DocumentReader
code, srcScope, err = src.ReadCodeWithScope()
if err != nil {
break
}
var dstScope DocumentWriter
dstScope, err = dst.WriteCodeWithScope(code)
if err != nil {
break
}
err = c.copyDocumentCore(dstScope, srcScope)
case bsontype.Int32:
var i32 int32
i32, err = src.ReadInt32()
if err != nil {
break
}
err = dst.WriteInt32(i32)
case bsontype.Timestamp:
var t, i uint32
t, i, err = src.ReadTimestamp()
if err != nil {
break
}
err = dst.WriteTimestamp(t, i)
case bsontype.Int64:
var i64 int64
i64, err = src.ReadInt64()
if err != nil {
break
}
err = dst.WriteInt64(i64)
case bsontype.Decimal128:
var d128 primitive.Decimal128
d128, err = src.ReadDecimal128()
if err != nil {
break
}
err = dst.WriteDecimal128(d128)
case bsontype.MinKey:
err = src.ReadMinKey()
if err != nil {
break
}
err = dst.WriteMinKey()
case bsontype.MaxKey:
err = src.ReadMaxKey()
if err != nil {
break
}
err = dst.WriteMaxKey()
default:
err = fmt.Errorf("Cannot copy unknown BSON type %s", src.Type())
}
return err
}
func (c Copier) copyArray(dst ValueWriter, src ValueReader) error {
ar, err := src.ReadArray()
if err != nil {
return err
}
aw, err := dst.WriteArray()
if err != nil {
return err
}
for {
vr, err := ar.ReadValue()
if err == ErrEOA {
break
}
if err != nil {
return err
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
err = c.CopyValue(vw, vr)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
func (c Copier) copyDocumentCore(dw DocumentWriter, dr DocumentReader) error {
for {
key, vr, err := dr.ReadElement()
if err == ErrEOD {
break
}
if err != nil {
return err
}
vw, err := dw.WriteDocumentElement(key)
if err != nil {
return err
}
err = c.CopyValue(vw, vr)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}

View File

@@ -0,0 +1,9 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bsonrw contains abstractions for reading and writing
// BSON and BSON like types from sources.
package bsonrw // import "go.mongodb.org/mongo-driver/bson/bsonrw"

View File

@@ -0,0 +1,731 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"errors"
"fmt"
"io"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
const maxNestingDepth = 200
// ErrInvalidJSON indicates the JSON input is invalid
var ErrInvalidJSON = errors.New("invalid JSON input")
type jsonParseState byte
const (
jpsStartState jsonParseState = iota
jpsSawBeginObject
jpsSawEndObject
jpsSawBeginArray
jpsSawEndArray
jpsSawColon
jpsSawComma
jpsSawKey
jpsSawValue
jpsDoneState
jpsInvalidState
)
type jsonParseMode byte
const (
jpmInvalidMode jsonParseMode = iota
jpmObjectMode
jpmArrayMode
)
type extJSONValue struct {
t bsontype.Type
v interface{}
}
type extJSONObject struct {
keys []string
values []*extJSONValue
}
type extJSONParser struct {
js *jsonScanner
s jsonParseState
m []jsonParseMode
k string
v *extJSONValue
err error
canonical bool
depth int
maxDepth int
emptyObject bool
}
// newExtJSONParser returns a new extended JSON parser, ready to to begin
// parsing from the first character of the argued json input. It will not
// perform any read-ahead and will therefore not report any errors about
// malformed JSON at this point.
func newExtJSONParser(r io.Reader, canonical bool) *extJSONParser {
return &extJSONParser{
js: &jsonScanner{r: r},
s: jpsStartState,
m: []jsonParseMode{},
canonical: canonical,
maxDepth: maxNestingDepth,
}
}
// peekType examines the next value and returns its BSON Type
func (ejp *extJSONParser) peekType() (bsontype.Type, error) {
var t bsontype.Type
var err error
ejp.advanceState()
switch ejp.s {
case jpsSawValue:
t = ejp.v.t
case jpsSawBeginArray:
t = bsontype.Array
case jpsInvalidState:
err = ejp.err
case jpsSawComma:
// in array mode, seeing a comma means we need to progress again to actually observe a type
if ejp.peekMode() == jpmArrayMode {
return ejp.peekType()
}
case jpsSawEndArray:
// this would only be a valid state if we were in array mode, so return end-of-array error
err = ErrEOA
case jpsSawBeginObject:
// peek key to determine type
ejp.advanceState()
switch ejp.s {
case jpsSawEndObject: // empty embedded document
t = bsontype.EmbeddedDocument
ejp.emptyObject = true
case jpsInvalidState:
err = ejp.err
case jpsSawKey:
t = wrapperKeyBSONType(ejp.k)
if t == bsontype.JavaScript {
// just saw $code, need to check for $scope at same level
_, err := ejp.readValue(bsontype.JavaScript)
if err != nil {
break
}
switch ejp.s {
case jpsSawEndObject: // type is TypeJavaScript
case jpsSawComma:
ejp.advanceState()
if ejp.s == jpsSawKey && ejp.k == "$scope" {
t = bsontype.CodeWithScope
} else {
err = fmt.Errorf("invalid extended JSON: unexpected key %s in CodeWithScope object", ejp.k)
}
case jpsInvalidState:
err = ejp.err
default:
err = ErrInvalidJSON
}
}
}
}
return t, err
}
// readKey parses the next key and its type and returns them
func (ejp *extJSONParser) readKey() (string, bsontype.Type, error) {
if ejp.emptyObject {
ejp.emptyObject = false
return "", 0, ErrEOD
}
// advance to key (or return with error)
switch ejp.s {
case jpsStartState:
ejp.advanceState()
if ejp.s == jpsSawBeginObject {
ejp.advanceState()
}
case jpsSawBeginObject:
ejp.advanceState()
case jpsSawValue, jpsSawEndObject, jpsSawEndArray:
ejp.advanceState()
switch ejp.s {
case jpsSawBeginObject, jpsSawComma:
ejp.advanceState()
case jpsSawEndObject:
return "", 0, ErrEOD
case jpsDoneState:
return "", 0, io.EOF
case jpsInvalidState:
return "", 0, ejp.err
default:
return "", 0, ErrInvalidJSON
}
case jpsSawKey: // do nothing (key was peeked before)
default:
return "", 0, invalidRequestError("key")
}
// read key
var key string
switch ejp.s {
case jpsSawKey:
key = ejp.k
case jpsSawEndObject:
return "", 0, ErrEOD
case jpsInvalidState:
return "", 0, ejp.err
default:
return "", 0, invalidRequestError("key")
}
// check for colon
ejp.advanceState()
if err := ensureColon(ejp.s, key); err != nil {
return "", 0, err
}
// peek at the value to determine type
t, err := ejp.peekType()
if err != nil {
return "", 0, err
}
return key, t, nil
}
// readValue returns the value corresponding to the Type returned by peekType
func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) {
if ejp.s == jpsInvalidState {
return nil, ejp.err
}
var v *extJSONValue
switch t {
case bsontype.Null, bsontype.Boolean, bsontype.String:
if ejp.s != jpsSawValue {
return nil, invalidRequestError(t.String())
}
v = ejp.v
case bsontype.Int32, bsontype.Int64, bsontype.Double:
// relaxed version allows these to be literal number values
if ejp.s == jpsSawValue {
v = ejp.v
break
}
fallthrough
case bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID, bsontype.MinKey, bsontype.MaxKey, bsontype.Undefined:
switch ejp.s {
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read value
ejp.advanceState()
if ejp.s != jpsSawValue || !ejp.ensureExtValueType(t) {
return nil, invalidJSONErrorForType("value", t)
}
v = ejp.v
// read end object
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("} after value", t)
}
default:
return nil, invalidRequestError(t.String())
}
case bsontype.Binary, bsontype.Regex, bsontype.Timestamp, bsontype.DBPointer:
if ejp.s != jpsSawKey {
return nil, invalidRequestError(t.String())
}
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
ejp.advanceState()
if t == bsontype.Binary && ejp.s == jpsSawValue {
// convert legacy $binary format
base64 := ejp.v
ejp.advanceState()
if ejp.s != jpsSawComma {
return nil, invalidJSONErrorForType(",", bsontype.Binary)
}
ejp.advanceState()
key, t, err := ejp.readKey()
if err != nil {
return nil, err
}
if key != "$type" {
return nil, invalidJSONErrorForType("$type", bsontype.Binary)
}
subType, err := ejp.readValue(t)
if err != nil {
return nil, err
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("2 key-value pairs and then }", bsontype.Binary)
}
v = &extJSONValue{
t: bsontype.EmbeddedDocument,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{base64, subType},
},
}
break
}
// read KV pairs
if ejp.s != jpsSawBeginObject {
return nil, invalidJSONErrorForType("{", t)
}
keys, vals, err := ejp.readObject(2, true)
if err != nil {
return nil, err
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("2 key-value pairs and then }", t)
}
v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
case bsontype.DateTime:
switch ejp.s {
case jpsSawValue:
v = ejp.v
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
ejp.advanceState()
switch ejp.s {
case jpsSawBeginObject:
keys, vals, err := ejp.readObject(1, true)
if err != nil {
return nil, err
}
v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
case jpsSawValue:
if ejp.canonical {
return nil, invalidJSONError("{")
}
v = ejp.v
default:
if ejp.canonical {
return nil, invalidJSONErrorForType("object", t)
}
return nil, invalidJSONErrorForType("ISO-8601 Internet Date/Time Format as decribed in RFC-3339", t)
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("value and then }", t)
}
default:
return nil, invalidRequestError(t.String())
}
case bsontype.JavaScript:
switch ejp.s {
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read value
ejp.advanceState()
if ejp.s != jpsSawValue {
return nil, invalidJSONErrorForType("value", t)
}
v = ejp.v
// read end object or comma and just return
ejp.advanceState()
case jpsSawEndObject:
v = ejp.v
default:
return nil, invalidRequestError(t.String())
}
case bsontype.CodeWithScope:
if ejp.s == jpsSawKey && ejp.k == "$scope" {
v = ejp.v // this is the $code string from earlier
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read {
ejp.advanceState()
if ejp.s != jpsSawBeginObject {
return nil, invalidJSONError("$scope to be embedded document")
}
} else {
return nil, invalidRequestError(t.String())
}
case bsontype.EmbeddedDocument, bsontype.Array:
return nil, invalidRequestError(t.String())
}
return v, nil
}
// readObject is a utility method for reading full objects of known (or expected) size
// it is useful for extended JSON types such as binary, datetime, regex, and timestamp
func (ejp *extJSONParser) readObject(numKeys int, started bool) ([]string, []*extJSONValue, error) {
keys := make([]string, numKeys)
vals := make([]*extJSONValue, numKeys)
if !started {
ejp.advanceState()
if ejp.s != jpsSawBeginObject {
return nil, nil, invalidJSONError("{")
}
}
for i := 0; i < numKeys; i++ {
key, t, err := ejp.readKey()
if err != nil {
return nil, nil, err
}
switch ejp.s {
case jpsSawKey:
v, err := ejp.readValue(t)
if err != nil {
return nil, nil, err
}
keys[i] = key
vals[i] = v
case jpsSawValue:
keys[i] = key
vals[i] = ejp.v
default:
return nil, nil, invalidJSONError("value")
}
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, nil, invalidJSONError("}")
}
return keys, vals, nil
}
// advanceState reads the next JSON token from the scanner and transitions
// from the current state based on that token's type
func (ejp *extJSONParser) advanceState() {
if ejp.s == jpsDoneState || ejp.s == jpsInvalidState {
return
}
jt, err := ejp.js.nextToken()
if err != nil {
ejp.err = err
ejp.s = jpsInvalidState
return
}
valid := ejp.validateToken(jt.t)
if !valid {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
return
}
switch jt.t {
case jttBeginObject:
ejp.s = jpsSawBeginObject
ejp.pushMode(jpmObjectMode)
ejp.depth++
if ejp.depth > ejp.maxDepth {
ejp.err = nestingDepthError(jt.p, ejp.depth)
ejp.s = jpsInvalidState
}
case jttEndObject:
ejp.s = jpsSawEndObject
ejp.depth--
if ejp.popMode() != jpmObjectMode {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttBeginArray:
ejp.s = jpsSawBeginArray
ejp.pushMode(jpmArrayMode)
case jttEndArray:
ejp.s = jpsSawEndArray
if ejp.popMode() != jpmArrayMode {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttColon:
ejp.s = jpsSawColon
case jttComma:
ejp.s = jpsSawComma
case jttEOF:
ejp.s = jpsDoneState
if len(ejp.m) != 0 {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttString:
switch ejp.s {
case jpsSawComma:
if ejp.peekMode() == jpmArrayMode {
ejp.s = jpsSawValue
ejp.v = extendJSONToken(jt)
return
}
fallthrough
case jpsSawBeginObject:
ejp.s = jpsSawKey
ejp.k = jt.v.(string)
return
}
fallthrough
default:
ejp.s = jpsSawValue
ejp.v = extendJSONToken(jt)
}
}
var jpsValidTransitionTokens = map[jsonParseState]map[jsonTokenType]bool{
jpsStartState: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
jttEOF: true,
},
jpsSawBeginObject: {
jttEndObject: true,
jttString: true,
},
jpsSawEndObject: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsSawBeginArray: {
jttBeginObject: true,
jttBeginArray: true,
jttEndArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawEndArray: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsSawColon: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawComma: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawKey: {
jttColon: true,
},
jpsSawValue: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsDoneState: {},
jpsInvalidState: {},
}
func (ejp *extJSONParser) validateToken(jtt jsonTokenType) bool {
switch ejp.s {
case jpsSawEndObject:
// if we are at depth zero and the next token is a '{',
// we can consider it valid only if we are not in array mode.
if jtt == jttBeginObject && ejp.depth == 0 {
return ejp.peekMode() != jpmArrayMode
}
case jpsSawComma:
switch ejp.peekMode() {
// the only valid next token after a comma inside a document is a string (a key)
case jpmObjectMode:
return jtt == jttString
case jpmInvalidMode:
return false
}
}
_, ok := jpsValidTransitionTokens[ejp.s][jtt]
return ok
}
// ensureExtValueType returns true if the current value has the expected
// value type for single-key extended JSON types. For example,
// {"$numberInt": v} v must be TypeString
func (ejp *extJSONParser) ensureExtValueType(t bsontype.Type) bool {
switch t {
case bsontype.MinKey, bsontype.MaxKey:
return ejp.v.t == bsontype.Int32
case bsontype.Undefined:
return ejp.v.t == bsontype.Boolean
case bsontype.Int32, bsontype.Int64, bsontype.Double, bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID:
return ejp.v.t == bsontype.String
default:
return false
}
}
func (ejp *extJSONParser) pushMode(m jsonParseMode) {
ejp.m = append(ejp.m, m)
}
func (ejp *extJSONParser) popMode() jsonParseMode {
l := len(ejp.m)
if l == 0 {
return jpmInvalidMode
}
m := ejp.m[l-1]
ejp.m = ejp.m[:l-1]
return m
}
func (ejp *extJSONParser) peekMode() jsonParseMode {
l := len(ejp.m)
if l == 0 {
return jpmInvalidMode
}
return ejp.m[l-1]
}
func extendJSONToken(jt *jsonToken) *extJSONValue {
var t bsontype.Type
switch jt.t {
case jttInt32:
t = bsontype.Int32
case jttInt64:
t = bsontype.Int64
case jttDouble:
t = bsontype.Double
case jttString:
t = bsontype.String
case jttBool:
t = bsontype.Boolean
case jttNull:
t = bsontype.Null
default:
return nil
}
return &extJSONValue{t: t, v: jt.v}
}
func ensureColon(s jsonParseState, key string) error {
if s != jpsSawColon {
return fmt.Errorf("invalid JSON input: missing colon after key \"%s\"", key)
}
return nil
}
func invalidRequestError(s string) error {
return fmt.Errorf("invalid request to read %s", s)
}
func invalidJSONError(expected string) error {
return fmt.Errorf("invalid JSON input; expected %s", expected)
}
func invalidJSONErrorForType(expected string, t bsontype.Type) error {
return fmt.Errorf("invalid JSON input; expected %s for %s", expected, t)
}
func unexpectedTokenError(jt *jsonToken) error {
switch jt.t {
case jttInt32, jttInt64, jttDouble:
return fmt.Errorf("invalid JSON input; unexpected number (%v) at position %d", jt.v, jt.p)
case jttString:
return fmt.Errorf("invalid JSON input; unexpected string (\"%v\") at position %d", jt.v, jt.p)
case jttBool:
return fmt.Errorf("invalid JSON input; unexpected boolean literal (%v) at position %d", jt.v, jt.p)
case jttNull:
return fmt.Errorf("invalid JSON input; unexpected null literal at position %d", jt.p)
case jttEOF:
return fmt.Errorf("invalid JSON input; unexpected end of input at position %d", jt.p)
default:
return fmt.Errorf("invalid JSON input; unexpected %c at position %d", jt.v.(byte), jt.p)
}
}
func nestingDepthError(p, depth int) error {
return fmt.Errorf("invalid JSON input; nesting too deep (%d levels) at position %d", depth, p)
}

View File

@@ -0,0 +1,659 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"fmt"
"io"
"sync"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ExtJSONValueReaderPool is a pool for ValueReaders that read ExtJSON.
type ExtJSONValueReaderPool struct {
pool sync.Pool
}
// NewExtJSONValueReaderPool instantiates a new ExtJSONValueReaderPool.
func NewExtJSONValueReaderPool() *ExtJSONValueReaderPool {
return &ExtJSONValueReaderPool{
pool: sync.Pool{
New: func() interface{} {
return new(extJSONValueReader)
},
},
}
}
// Get retrieves a ValueReader from the pool and uses src as the underlying ExtJSON.
func (bvrp *ExtJSONValueReaderPool) Get(r io.Reader, canonical bool) (ValueReader, error) {
vr := bvrp.pool.Get().(*extJSONValueReader)
return vr.reset(r, canonical)
}
// Put inserts a ValueReader into the pool. If the ValueReader is not a ExtJSON ValueReader nothing
// is inserted into the pool and ok will be false.
func (bvrp *ExtJSONValueReaderPool) Put(vr ValueReader) (ok bool) {
bvr, ok := vr.(*extJSONValueReader)
if !ok {
return false
}
bvr, _ = bvr.reset(nil, false)
bvrp.pool.Put(bvr)
return true
}
type ejvrState struct {
mode mode
vType bsontype.Type
depth int
}
// extJSONValueReader is for reading extended JSON.
type extJSONValueReader struct {
p *extJSONParser
stack []ejvrState
frame int
}
// NewExtJSONValueReader creates a new ValueReader from a given io.Reader
// It will interpret the JSON of r as canonical or relaxed according to the
// given canonical flag
func NewExtJSONValueReader(r io.Reader, canonical bool) (ValueReader, error) {
return newExtJSONValueReader(r, canonical)
}
func newExtJSONValueReader(r io.Reader, canonical bool) (*extJSONValueReader, error) {
ejvr := new(extJSONValueReader)
return ejvr.reset(r, canonical)
}
func (ejvr *extJSONValueReader) reset(r io.Reader, canonical bool) (*extJSONValueReader, error) {
p := newExtJSONParser(r, canonical)
typ, err := p.peekType()
if err != nil {
return nil, ErrInvalidJSON
}
var m mode
switch typ {
case bsontype.EmbeddedDocument:
m = mTopLevel
case bsontype.Array:
m = mArray
default:
m = mValue
}
stack := make([]ejvrState, 1, 5)
stack[0] = ejvrState{
mode: m,
vType: typ,
}
return &extJSONValueReader{
p: p,
stack: stack,
}, nil
}
func (ejvr *extJSONValueReader) advanceFrame() {
if ejvr.frame+1 >= len(ejvr.stack) { // We need to grow the stack
length := len(ejvr.stack)
if length+1 >= cap(ejvr.stack) {
// double it
buf := make([]ejvrState, 2*cap(ejvr.stack)+1)
copy(buf, ejvr.stack)
ejvr.stack = buf
}
ejvr.stack = ejvr.stack[:length+1]
}
ejvr.frame++
// Clean the stack
ejvr.stack[ejvr.frame].mode = 0
ejvr.stack[ejvr.frame].vType = 0
ejvr.stack[ejvr.frame].depth = 0
}
func (ejvr *extJSONValueReader) pushDocument() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mDocument
ejvr.stack[ejvr.frame].depth = ejvr.p.depth
}
func (ejvr *extJSONValueReader) pushCodeWithScope() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mCodeWithScope
}
func (ejvr *extJSONValueReader) pushArray() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mArray
}
func (ejvr *extJSONValueReader) push(m mode, t bsontype.Type) {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = m
ejvr.stack[ejvr.frame].vType = t
}
func (ejvr *extJSONValueReader) pop() {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
ejvr.frame--
case mDocument, mArray, mCodeWithScope:
ejvr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc...
}
}
func (ejvr *extJSONValueReader) skipDocument() error {
// read entire document until ErrEOD (using readKey and readValue)
_, typ, err := ejvr.p.readKey()
for err == nil {
_, err = ejvr.p.readValue(typ)
if err != nil {
break
}
_, typ, err = ejvr.p.readKey()
}
return err
}
func (ejvr *extJSONValueReader) skipArray() error {
// read entire array until ErrEOA (using peekType)
_, err := ejvr.p.peekType()
for err == nil {
_, err = ejvr.p.peekType()
}
return err
}
func (ejvr *extJSONValueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: ejvr.stack[ejvr.frame].mode,
destination: destination,
modes: modes,
action: "read",
}
if ejvr.frame != 0 {
te.parent = ejvr.stack[ejvr.frame-1].mode
}
return te
}
func (ejvr *extJSONValueReader) typeError(t bsontype.Type) error {
return fmt.Errorf("positioned on %s, but attempted to read %s", ejvr.stack[ejvr.frame].vType, t)
}
func (ejvr *extJSONValueReader) ensureElementValue(t bsontype.Type, destination mode, callerName string, addModes ...mode) error {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
if ejvr.stack[ejvr.frame].vType != t {
return ejvr.typeError(t)
}
default:
modes := []mode{mElement, mValue}
if addModes != nil {
modes = append(modes, addModes...)
}
return ejvr.invalidTransitionErr(destination, callerName, modes)
}
return nil
}
func (ejvr *extJSONValueReader) Type() bsontype.Type {
return ejvr.stack[ejvr.frame].vType
}
func (ejvr *extJSONValueReader) Skip() error {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
default:
return ejvr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
}
defer ejvr.pop()
t := ejvr.stack[ejvr.frame].vType
switch t {
case bsontype.Array:
// read entire array until ErrEOA
err := ejvr.skipArray()
if err != ErrEOA {
return err
}
case bsontype.EmbeddedDocument:
// read entire doc until ErrEOD
err := ejvr.skipDocument()
if err != ErrEOD {
return err
}
case bsontype.CodeWithScope:
// read the code portion and set up parser in document mode
_, err := ejvr.p.readValue(t)
if err != nil {
return err
}
// read until ErrEOD
err = ejvr.skipDocument()
if err != ErrEOD {
return err
}
default:
_, err := ejvr.p.readValue(t)
if err != nil {
return err
}
}
return nil
}
func (ejvr *extJSONValueReader) ReadArray() (ArrayReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel: // allow reading array from top level
case mArray:
return ejvr, nil
default:
if err := ejvr.ensureElementValue(bsontype.Array, mArray, "ReadArray", mTopLevel, mArray); err != nil {
return nil, err
}
}
ejvr.pushArray()
return ejvr, nil
}
func (ejvr *extJSONValueReader) ReadBinary() (b []byte, btype byte, err error) {
if err := ejvr.ensureElementValue(bsontype.Binary, 0, "ReadBinary"); err != nil {
return nil, 0, err
}
v, err := ejvr.p.readValue(bsontype.Binary)
if err != nil {
return nil, 0, err
}
b, btype, err = v.parseBinary()
ejvr.pop()
return b, btype, err
}
func (ejvr *extJSONValueReader) ReadBoolean() (bool, error) {
if err := ejvr.ensureElementValue(bsontype.Boolean, 0, "ReadBoolean"); err != nil {
return false, err
}
v, err := ejvr.p.readValue(bsontype.Boolean)
if err != nil {
return false, err
}
if v.t != bsontype.Boolean {
return false, fmt.Errorf("expected type bool, but got type %s", v.t)
}
ejvr.pop()
return v.v.(bool), nil
}
func (ejvr *extJSONValueReader) ReadDocument() (DocumentReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel:
return ejvr, nil
case mElement, mValue:
if ejvr.stack[ejvr.frame].vType != bsontype.EmbeddedDocument {
return nil, ejvr.typeError(bsontype.EmbeddedDocument)
}
ejvr.pushDocument()
return ejvr, nil
default:
return nil, ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
}
}
func (ejvr *extJSONValueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
if err = ejvr.ensureElementValue(bsontype.CodeWithScope, 0, "ReadCodeWithScope"); err != nil {
return "", nil, err
}
v, err := ejvr.p.readValue(bsontype.CodeWithScope)
if err != nil {
return "", nil, err
}
code, err = v.parseJavascript()
ejvr.pushCodeWithScope()
return code, ejvr, err
}
func (ejvr *extJSONValueReader) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) {
if err = ejvr.ensureElementValue(bsontype.DBPointer, 0, "ReadDBPointer"); err != nil {
return "", primitive.NilObjectID, err
}
v, err := ejvr.p.readValue(bsontype.DBPointer)
if err != nil {
return "", primitive.NilObjectID, err
}
ns, oid, err = v.parseDBPointer()
ejvr.pop()
return ns, oid, err
}
func (ejvr *extJSONValueReader) ReadDateTime() (int64, error) {
if err := ejvr.ensureElementValue(bsontype.DateTime, 0, "ReadDateTime"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(bsontype.DateTime)
if err != nil {
return 0, err
}
d, err := v.parseDateTime()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadDecimal128() (primitive.Decimal128, error) {
if err := ejvr.ensureElementValue(bsontype.Decimal128, 0, "ReadDecimal128"); err != nil {
return primitive.Decimal128{}, err
}
v, err := ejvr.p.readValue(bsontype.Decimal128)
if err != nil {
return primitive.Decimal128{}, err
}
d, err := v.parseDecimal128()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadDouble() (float64, error) {
if err := ejvr.ensureElementValue(bsontype.Double, 0, "ReadDouble"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(bsontype.Double)
if err != nil {
return 0, err
}
d, err := v.parseDouble()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadInt32() (int32, error) {
if err := ejvr.ensureElementValue(bsontype.Int32, 0, "ReadInt32"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(bsontype.Int32)
if err != nil {
return 0, err
}
i, err := v.parseInt32()
ejvr.pop()
return i, err
}
func (ejvr *extJSONValueReader) ReadInt64() (int64, error) {
if err := ejvr.ensureElementValue(bsontype.Int64, 0, "ReadInt64"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(bsontype.Int64)
if err != nil {
return 0, err
}
i, err := v.parseInt64()
ejvr.pop()
return i, err
}
func (ejvr *extJSONValueReader) ReadJavascript() (code string, err error) {
if err = ejvr.ensureElementValue(bsontype.JavaScript, 0, "ReadJavascript"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(bsontype.JavaScript)
if err != nil {
return "", err
}
code, err = v.parseJavascript()
ejvr.pop()
return code, err
}
func (ejvr *extJSONValueReader) ReadMaxKey() error {
if err := ejvr.ensureElementValue(bsontype.MaxKey, 0, "ReadMaxKey"); err != nil {
return err
}
v, err := ejvr.p.readValue(bsontype.MaxKey)
if err != nil {
return err
}
err = v.parseMinMaxKey("max")
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadMinKey() error {
if err := ejvr.ensureElementValue(bsontype.MinKey, 0, "ReadMinKey"); err != nil {
return err
}
v, err := ejvr.p.readValue(bsontype.MinKey)
if err != nil {
return err
}
err = v.parseMinMaxKey("min")
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadNull() error {
if err := ejvr.ensureElementValue(bsontype.Null, 0, "ReadNull"); err != nil {
return err
}
v, err := ejvr.p.readValue(bsontype.Null)
if err != nil {
return err
}
if v.t != bsontype.Null {
return fmt.Errorf("expected type null but got type %s", v.t)
}
ejvr.pop()
return nil
}
func (ejvr *extJSONValueReader) ReadObjectID() (primitive.ObjectID, error) {
if err := ejvr.ensureElementValue(bsontype.ObjectID, 0, "ReadObjectID"); err != nil {
return primitive.ObjectID{}, err
}
v, err := ejvr.p.readValue(bsontype.ObjectID)
if err != nil {
return primitive.ObjectID{}, err
}
oid, err := v.parseObjectID()
ejvr.pop()
return oid, err
}
func (ejvr *extJSONValueReader) ReadRegex() (pattern string, options string, err error) {
if err = ejvr.ensureElementValue(bsontype.Regex, 0, "ReadRegex"); err != nil {
return "", "", err
}
v, err := ejvr.p.readValue(bsontype.Regex)
if err != nil {
return "", "", err
}
pattern, options, err = v.parseRegex()
ejvr.pop()
return pattern, options, err
}
func (ejvr *extJSONValueReader) ReadString() (string, error) {
if err := ejvr.ensureElementValue(bsontype.String, 0, "ReadString"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(bsontype.String)
if err != nil {
return "", err
}
if v.t != bsontype.String {
return "", fmt.Errorf("expected type string but got type %s", v.t)
}
ejvr.pop()
return v.v.(string), nil
}
func (ejvr *extJSONValueReader) ReadSymbol() (symbol string, err error) {
if err = ejvr.ensureElementValue(bsontype.Symbol, 0, "ReadSymbol"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(bsontype.Symbol)
if err != nil {
return "", err
}
symbol, err = v.parseSymbol()
ejvr.pop()
return symbol, err
}
func (ejvr *extJSONValueReader) ReadTimestamp() (t uint32, i uint32, err error) {
if err = ejvr.ensureElementValue(bsontype.Timestamp, 0, "ReadTimestamp"); err != nil {
return 0, 0, err
}
v, err := ejvr.p.readValue(bsontype.Timestamp)
if err != nil {
return 0, 0, err
}
t, i, err = v.parseTimestamp()
ejvr.pop()
return t, i, err
}
func (ejvr *extJSONValueReader) ReadUndefined() error {
if err := ejvr.ensureElementValue(bsontype.Undefined, 0, "ReadUndefined"); err != nil {
return err
}
v, err := ejvr.p.readValue(bsontype.Undefined)
if err != nil {
return err
}
err = v.parseUndefined()
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel, mDocument, mCodeWithScope:
default:
return "", nil, ejvr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
}
name, t, err := ejvr.p.readKey()
if err != nil {
if err == ErrEOD {
if ejvr.stack[ejvr.frame].mode == mCodeWithScope {
_, err := ejvr.p.peekType()
if err != nil {
return "", nil, err
}
}
ejvr.pop()
}
return "", nil, err
}
ejvr.push(mElement, t)
return name, ejvr, nil
}
func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mArray:
default:
return nil, ejvr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
}
t, err := ejvr.p.peekType()
if err != nil {
if err == ErrEOA {
ejvr.pop()
}
return nil, err
}
ejvr.push(mValue, t)
return ejvr, nil
}

View File

@@ -0,0 +1,223 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/golang/go by The Go Authors
// See THIRD-PARTY-NOTICES for original license terms.
package bsonrw
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': true,
'=': true,
'>': true,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': false,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': false,
'=': true,
'>': false,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}

View File

@@ -0,0 +1,481 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"encoding/base64"
"errors"
"fmt"
"math"
"strconv"
"time"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func wrapperKeyBSONType(key string) bsontype.Type {
switch string(key) {
case "$numberInt":
return bsontype.Int32
case "$numberLong":
return bsontype.Int64
case "$oid":
return bsontype.ObjectID
case "$symbol":
return bsontype.Symbol
case "$numberDouble":
return bsontype.Double
case "$numberDecimal":
return bsontype.Decimal128
case "$binary":
return bsontype.Binary
case "$code":
return bsontype.JavaScript
case "$scope":
return bsontype.CodeWithScope
case "$timestamp":
return bsontype.Timestamp
case "$regularExpression":
return bsontype.Regex
case "$dbPointer":
return bsontype.DBPointer
case "$date":
return bsontype.DateTime
case "$ref":
fallthrough
case "$id":
fallthrough
case "$db":
return bsontype.EmbeddedDocument // dbrefs aren't bson types
case "$minKey":
return bsontype.MinKey
case "$maxKey":
return bsontype.MaxKey
case "$undefined":
return bsontype.Undefined
}
return bsontype.EmbeddedDocument
}
func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) {
if ejv.t != bsontype.EmbeddedDocument {
return nil, 0, fmt.Errorf("$binary value should be object, but instead is %s", ejv.t)
}
binObj := ejv.v.(*extJSONObject)
bFound := false
stFound := false
for i, key := range binObj.keys {
val := binObj.values[i]
switch key {
case "base64":
if bFound {
return nil, 0, errors.New("duplicate base64 key in $binary")
}
if val.t != bsontype.String {
return nil, 0, fmt.Errorf("$binary base64 value should be string, but instead is %s", val.t)
}
base64Bytes, err := base64.StdEncoding.DecodeString(val.v.(string))
if err != nil {
return nil, 0, fmt.Errorf("invalid $binary base64 string: %s", val.v.(string))
}
b = base64Bytes
bFound = true
case "subType":
if stFound {
return nil, 0, errors.New("duplicate subType key in $binary")
}
if val.t != bsontype.String {
return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t)
}
i, err := strconv.ParseInt(val.v.(string), 16, 64)
if err != nil {
return nil, 0, fmt.Errorf("invalid $binary subType string: %s", val.v.(string))
}
subType = byte(i)
stFound = true
default:
return nil, 0, fmt.Errorf("invalid key in $binary object: %s", key)
}
}
if !bFound {
return nil, 0, errors.New("missing base64 field in $binary object")
}
if !stFound {
return nil, 0, errors.New("missing subType field in $binary object")
}
return b, subType, nil
}
func (ejv *extJSONValue) parseDBPointer() (ns string, oid primitive.ObjectID, err error) {
if ejv.t != bsontype.EmbeddedDocument {
return "", primitive.NilObjectID, fmt.Errorf("$dbPointer value should be object, but instead is %s", ejv.t)
}
dbpObj := ejv.v.(*extJSONObject)
oidFound := false
nsFound := false
for i, key := range dbpObj.keys {
val := dbpObj.values[i]
switch key {
case "$ref":
if nsFound {
return "", primitive.NilObjectID, errors.New("duplicate $ref key in $dbPointer")
}
if val.t != bsontype.String {
return "", primitive.NilObjectID, fmt.Errorf("$dbPointer $ref value should be string, but instead is %s", val.t)
}
ns = val.v.(string)
nsFound = true
case "$id":
if oidFound {
return "", primitive.NilObjectID, errors.New("duplicate $id key in $dbPointer")
}
if val.t != bsontype.String {
return "", primitive.NilObjectID, fmt.Errorf("$dbPointer $id value should be string, but instead is %s", val.t)
}
oid, err = primitive.ObjectIDFromHex(val.v.(string))
if err != nil {
return "", primitive.NilObjectID, err
}
oidFound = true
default:
return "", primitive.NilObjectID, fmt.Errorf("invalid key in $dbPointer object: %s", key)
}
}
if !nsFound {
return "", oid, errors.New("missing $ref field in $dbPointer object")
}
if !oidFound {
return "", oid, errors.New("missing $id field in $dbPointer object")
}
return ns, oid, nil
}
const rfc3339Milli = "2006-01-02T15:04:05.999Z07:00"
func (ejv *extJSONValue) parseDateTime() (int64, error) {
switch ejv.t {
case bsontype.Int32:
return int64(ejv.v.(int32)), nil
case bsontype.Int64:
return ejv.v.(int64), nil
case bsontype.String:
return parseDatetimeString(ejv.v.(string))
case bsontype.EmbeddedDocument:
return parseDatetimeObject(ejv.v.(*extJSONObject))
default:
return 0, fmt.Errorf("$date value should be string or object, but instead is %s", ejv.t)
}
}
func parseDatetimeString(data string) (int64, error) {
t, err := time.Parse(rfc3339Milli, data)
if err != nil {
return 0, fmt.Errorf("invalid $date value string: %s", data)
}
return t.Unix()*1e3 + int64(t.Nanosecond())/1e6, nil
}
func parseDatetimeObject(data *extJSONObject) (d int64, err error) {
dFound := false
for i, key := range data.keys {
val := data.values[i]
switch key {
case "$numberLong":
if dFound {
return 0, errors.New("duplicate $numberLong key in $date")
}
if val.t != bsontype.String {
return 0, fmt.Errorf("$date $numberLong field should be string, but instead is %s", val.t)
}
d, err = val.parseInt64()
if err != nil {
return 0, err
}
dFound = true
default:
return 0, fmt.Errorf("invalid key in $date object: %s", key)
}
}
if !dFound {
return 0, errors.New("missing $numberLong field in $date object")
}
return d, nil
}
func (ejv *extJSONValue) parseDecimal128() (primitive.Decimal128, error) {
if ejv.t != bsontype.String {
return primitive.Decimal128{}, fmt.Errorf("$numberDecimal value should be string, but instead is %s", ejv.t)
}
d, err := primitive.ParseDecimal128(ejv.v.(string))
if err != nil {
return primitive.Decimal128{}, fmt.Errorf("$invalid $numberDecimal string: %s", ejv.v.(string))
}
return d, nil
}
func (ejv *extJSONValue) parseDouble() (float64, error) {
if ejv.t == bsontype.Double {
return ejv.v.(float64), nil
}
if ejv.t != bsontype.String {
return 0, fmt.Errorf("$numberDouble value should be string, but instead is %s", ejv.t)
}
switch string(ejv.v.(string)) {
case "Infinity":
return math.Inf(1), nil
case "-Infinity":
return math.Inf(-1), nil
case "NaN":
return math.NaN(), nil
}
f, err := strconv.ParseFloat(ejv.v.(string), 64)
if err != nil {
return 0, err
}
return f, nil
}
func (ejv *extJSONValue) parseInt32() (int32, error) {
if ejv.t == bsontype.Int32 {
return ejv.v.(int32), nil
}
if ejv.t != bsontype.String {
return 0, fmt.Errorf("$numberInt value should be string, but instead is %s", ejv.t)
}
i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
if err != nil {
return 0, err
}
if i < math.MinInt32 || i > math.MaxInt32 {
return 0, fmt.Errorf("$numberInt value should be int32 but instead is int64: %d", i)
}
return int32(i), nil
}
func (ejv *extJSONValue) parseInt64() (int64, error) {
if ejv.t == bsontype.Int64 {
return ejv.v.(int64), nil
}
if ejv.t != bsontype.String {
return 0, fmt.Errorf("$numberLong value should be string, but instead is %s", ejv.t)
}
i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
if err != nil {
return 0, err
}
return i, nil
}
func (ejv *extJSONValue) parseJavascript() (code string, err error) {
if ejv.t != bsontype.String {
return "", fmt.Errorf("$code value should be string, but instead is %s", ejv.t)
}
return ejv.v.(string), nil
}
func (ejv *extJSONValue) parseMinMaxKey(minmax string) error {
if ejv.t != bsontype.Int32 {
return fmt.Errorf("$%sKey value should be int32, but instead is %s", minmax, ejv.t)
}
if ejv.v.(int32) != 1 {
return fmt.Errorf("$%sKey value must be 1, but instead is %d", minmax, ejv.v.(int32))
}
return nil
}
func (ejv *extJSONValue) parseObjectID() (primitive.ObjectID, error) {
if ejv.t != bsontype.String {
return primitive.NilObjectID, fmt.Errorf("$oid value should be string, but instead is %s", ejv.t)
}
return primitive.ObjectIDFromHex(ejv.v.(string))
}
func (ejv *extJSONValue) parseRegex() (pattern, options string, err error) {
if ejv.t != bsontype.EmbeddedDocument {
return "", "", fmt.Errorf("$regularExpression value should be object, but instead is %s", ejv.t)
}
regexObj := ejv.v.(*extJSONObject)
patFound := false
optFound := false
for i, key := range regexObj.keys {
val := regexObj.values[i]
switch string(key) {
case "pattern":
if patFound {
return "", "", errors.New("duplicate pattern key in $regularExpression")
}
if val.t != bsontype.String {
return "", "", fmt.Errorf("$regularExpression pattern value should be string, but instead is %s", val.t)
}
pattern = val.v.(string)
patFound = true
case "options":
if optFound {
return "", "", errors.New("duplicate options key in $regularExpression")
}
if val.t != bsontype.String {
return "", "", fmt.Errorf("$regularExpression options value should be string, but instead is %s", val.t)
}
options = val.v.(string)
optFound = true
default:
return "", "", fmt.Errorf("invalid key in $regularExpression object: %s", key)
}
}
if !patFound {
return "", "", errors.New("missing pattern field in $regularExpression object")
}
if !optFound {
return "", "", errors.New("missing options field in $regularExpression object")
}
return pattern, options, nil
}
func (ejv *extJSONValue) parseSymbol() (string, error) {
if ejv.t != bsontype.String {
return "", fmt.Errorf("$symbol value should be string, but instead is %s", ejv.t)
}
return ejv.v.(string), nil
}
func (ejv *extJSONValue) parseTimestamp() (t, i uint32, err error) {
if ejv.t != bsontype.EmbeddedDocument {
return 0, 0, fmt.Errorf("$timestamp value should be object, but instead is %s", ejv.t)
}
handleKey := func(key string, val *extJSONValue, flag bool) (uint32, error) {
if flag {
return 0, fmt.Errorf("duplicate %s key in $timestamp", key)
}
switch val.t {
case bsontype.Int32:
if val.v.(int32) < 0 {
return 0, fmt.Errorf("$timestamp %s number should be uint32: %s", key, string(val.v.(int32)))
}
return uint32(val.v.(int32)), nil
case bsontype.Int64:
if val.v.(int64) < 0 || uint32(val.v.(int64)) > math.MaxUint32 {
return 0, fmt.Errorf("$timestamp %s number should be uint32: %s", key, string(val.v.(int32)))
}
return uint32(val.v.(int64)), nil
default:
return 0, fmt.Errorf("$timestamp %s value should be uint32, but instead is %s", key, val.t)
}
}
tsObj := ejv.v.(*extJSONObject)
tFound := false
iFound := false
for j, key := range tsObj.keys {
val := tsObj.values[j]
switch key {
case "t":
if t, err = handleKey(key, val, tFound); err != nil {
return 0, 0, err
}
tFound = true
case "i":
if i, err = handleKey(key, val, iFound); err != nil {
return 0, 0, err
}
iFound = true
default:
return 0, 0, fmt.Errorf("invalid key in $timestamp object: %s", key)
}
}
if !tFound {
return 0, 0, errors.New("missing t field in $timestamp object")
}
if !iFound {
return 0, 0, errors.New("missing i field in $timestamp object")
}
return t, i, nil
}
func (ejv *extJSONValue) parseUndefined() error {
if ejv.t != bsontype.Boolean {
return fmt.Errorf("undefined value should be boolean, but instead is %s", ejv.t)
}
if !ejv.v.(bool) {
return fmt.Errorf("$undefined balue boolean should be true, but instead is %v", ejv.v.(bool))
}
return nil
}

View File

@@ -0,0 +1,734 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"bytes"
"encoding/base64"
"fmt"
"go.mongodb.org/mongo-driver/bson/primitive"
"io"
"math"
"sort"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
)
var ejvwPool = sync.Pool{
New: func() interface{} {
return new(extJSONValueWriter)
},
}
// ExtJSONValueWriterPool is a pool for ExtJSON ValueWriters.
type ExtJSONValueWriterPool struct {
pool sync.Pool
}
// NewExtJSONValueWriterPool creates a new pool for ValueWriter instances that write to ExtJSON.
func NewExtJSONValueWriterPool() *ExtJSONValueWriterPool {
return &ExtJSONValueWriterPool{
pool: sync.Pool{
New: func() interface{} {
return new(extJSONValueWriter)
},
},
}
}
// Get retrieves a ExtJSON ValueWriter from the pool and resets it to use w as the destination.
func (bvwp *ExtJSONValueWriterPool) Get(w io.Writer, canonical, escapeHTML bool) ValueWriter {
vw := bvwp.pool.Get().(*extJSONValueWriter)
if writer, ok := w.(*SliceWriter); ok {
vw.reset(*writer, canonical, escapeHTML)
vw.w = writer
return vw
}
vw.buf = vw.buf[:0]
vw.w = w
return vw
}
// Put inserts a ValueWriter into the pool. If the ValueWriter is not a ExtJSON ValueWriter, nothing
// happens and ok will be false.
func (bvwp *ExtJSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
bvw, ok := vw.(*extJSONValueWriter)
if !ok {
return false
}
if _, ok := bvw.w.(*SliceWriter); ok {
bvw.buf = nil
}
bvw.w = nil
bvwp.pool.Put(bvw)
return true
}
type ejvwState struct {
mode mode
}
type extJSONValueWriter struct {
w io.Writer
buf []byte
stack []ejvwState
frame int64
canonical bool
escapeHTML bool
}
// NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w.
func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) (ValueWriter, error) {
if w == nil {
return nil, errNilWriter
}
return newExtJSONWriter(w, canonical, escapeHTML), nil
}
func newExtJSONWriter(w io.Writer, canonical, escapeHTML bool) *extJSONValueWriter {
stack := make([]ejvwState, 1, 5)
stack[0] = ejvwState{mode: mTopLevel}
return &extJSONValueWriter{
w: w,
buf: []byte{},
stack: stack,
canonical: canonical,
escapeHTML: escapeHTML,
}
}
func newExtJSONWriterFromSlice(buf []byte, canonical, escapeHTML bool) *extJSONValueWriter {
stack := make([]ejvwState, 1, 5)
stack[0] = ejvwState{mode: mTopLevel}
return &extJSONValueWriter{
buf: buf,
stack: stack,
canonical: canonical,
escapeHTML: escapeHTML,
}
}
func (ejvw *extJSONValueWriter) reset(buf []byte, canonical, escapeHTML bool) {
if ejvw.stack == nil {
ejvw.stack = make([]ejvwState, 1, 5)
}
ejvw.stack = ejvw.stack[:1]
ejvw.stack[0] = ejvwState{mode: mTopLevel}
ejvw.canonical = canonical
ejvw.escapeHTML = escapeHTML
ejvw.frame = 0
ejvw.buf = buf
ejvw.w = nil
}
func (ejvw *extJSONValueWriter) advanceFrame() {
if ejvw.frame+1 >= int64(len(ejvw.stack)) { // We need to grow the stack
length := len(ejvw.stack)
if length+1 >= cap(ejvw.stack) {
// double it
buf := make([]ejvwState, 2*cap(ejvw.stack)+1)
copy(buf, ejvw.stack)
ejvw.stack = buf
}
ejvw.stack = ejvw.stack[:length+1]
}
ejvw.frame++
}
func (ejvw *extJSONValueWriter) push(m mode) {
ejvw.advanceFrame()
ejvw.stack[ejvw.frame].mode = m
}
func (ejvw *extJSONValueWriter) pop() {
switch ejvw.stack[ejvw.frame].mode {
case mElement, mValue:
ejvw.frame--
case mDocument, mArray, mCodeWithScope:
ejvw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
}
}
func (ejvw *extJSONValueWriter) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: ejvw.stack[ejvw.frame].mode,
destination: destination,
modes: modes,
action: "write",
}
if ejvw.frame != 0 {
te.parent = ejvw.stack[ejvw.frame-1].mode
}
return te
}
func (ejvw *extJSONValueWriter) ensureElementValue(destination mode, callerName string, addmodes ...mode) error {
switch ejvw.stack[ejvw.frame].mode {
case mElement, mValue:
default:
modes := []mode{mElement, mValue}
if addmodes != nil {
modes = append(modes, addmodes...)
}
return ejvw.invalidTransitionErr(destination, callerName, modes)
}
return nil
}
func (ejvw *extJSONValueWriter) writeExtendedSingleValue(key string, value string, quotes bool) {
var s string
if quotes {
s = fmt.Sprintf(`{"$%s":"%s"}`, key, value)
} else {
s = fmt.Sprintf(`{"$%s":%s}`, key, value)
}
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
func (ejvw *extJSONValueWriter) WriteArray() (ArrayWriter, error) {
if err := ejvw.ensureElementValue(mArray, "WriteArray"); err != nil {
return nil, err
}
ejvw.buf = append(ejvw.buf, '[')
ejvw.push(mArray)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteBinary(b []byte) error {
return ejvw.WriteBinaryWithSubtype(b, 0x00)
}
func (ejvw *extJSONValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
if err := ejvw.ensureElementValue(mode(0), "WriteBinaryWithSubtype"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$binary":{"base64":"`)
buf.WriteString(base64.StdEncoding.EncodeToString(b))
buf.WriteString(fmt.Sprintf(`","subType":"%02x"}},`, btype))
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteBoolean(b bool) error {
if err := ejvw.ensureElementValue(mode(0), "WriteBoolean"); err != nil {
return err
}
ejvw.buf = append(ejvw.buf, []byte(strconv.FormatBool(b))...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
if err := ejvw.ensureElementValue(mCodeWithScope, "WriteCodeWithScope"); err != nil {
return nil, err
}
var buf bytes.Buffer
buf.WriteString(`{"$code":`)
writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
buf.WriteString(`,"$scope":{`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.push(mCodeWithScope)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDBPointer"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$dbPointer":{"$ref":"`)
buf.WriteString(ns)
buf.WriteString(`","$id":{"$oid":"`)
buf.WriteString(oid.Hex())
buf.WriteString(`"}}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDateTime(dt int64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDateTime"); err != nil {
return err
}
t := time.Unix(dt/1e3, dt%1e3*1e6).UTC()
if ejvw.canonical || t.Year() < 1970 || t.Year() > 9999 {
s := fmt.Sprintf(`{"$numberLong":"%d"}`, dt)
ejvw.writeExtendedSingleValue("date", s, false)
} else {
ejvw.writeExtendedSingleValue("date", t.Format(rfc3339Milli), true)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDecimal128(d primitive.Decimal128) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDecimal128"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("numberDecimal", d.String(), true)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDocument() (DocumentWriter, error) {
if ejvw.stack[ejvw.frame].mode == mTopLevel {
ejvw.buf = append(ejvw.buf, '{')
return ejvw, nil
}
if err := ejvw.ensureElementValue(mDocument, "WriteDocument", mTopLevel); err != nil {
return nil, err
}
ejvw.buf = append(ejvw.buf, '{')
ejvw.push(mDocument)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDouble(f float64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDouble"); err != nil {
return err
}
s := formatDouble(f)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberDouble", s, true)
} else {
switch s {
case "Infinity":
fallthrough
case "-Infinity":
fallthrough
case "NaN":
s = fmt.Sprintf(`{"$numberDouble":"%s"}`, s)
}
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteInt32(i int32) error {
if err := ejvw.ensureElementValue(mode(0), "WriteInt32"); err != nil {
return err
}
s := strconv.FormatInt(int64(i), 10)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberInt", s, true)
} else {
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteInt64(i int64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteInt64"); err != nil {
return err
}
s := strconv.FormatInt(i, 10)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberLong", s, true)
} else {
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteJavascript(code string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteJavascript"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
ejvw.writeExtendedSingleValue("code", buf.String(), false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteMaxKey() error {
if err := ejvw.ensureElementValue(mode(0), "WriteMaxKey"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("maxKey", "1", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteMinKey() error {
if err := ejvw.ensureElementValue(mode(0), "WriteMinKey"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("minKey", "1", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteNull() error {
if err := ejvw.ensureElementValue(mode(0), "WriteNull"); err != nil {
return err
}
ejvw.buf = append(ejvw.buf, []byte("null")...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteObjectID(oid primitive.ObjectID) error {
if err := ejvw.ensureElementValue(mode(0), "WriteObjectID"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("oid", oid.Hex(), true)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteRegex"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$regularExpression":{"pattern":`)
writeStringWithEscapes(pattern, &buf, ejvw.escapeHTML)
buf.WriteString(`,"options":"`)
buf.WriteString(sortStringAlphebeticAscending(options))
buf.WriteString(`"}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteString(s string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteString"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(s, &buf, ejvw.escapeHTML)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteSymbol(symbol string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteSymbol"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(symbol, &buf, ejvw.escapeHTML)
ejvw.writeExtendedSingleValue("symbol", buf.String(), false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteTimestamp(t uint32, i uint32) error {
if err := ejvw.ensureElementValue(mode(0), "WriteTimestamp"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$timestamp":{"t":`)
buf.WriteString(strconv.FormatUint(uint64(t), 10))
buf.WriteString(`,"i":`)
buf.WriteString(strconv.FormatUint(uint64(i), 10))
buf.WriteString(`}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteUndefined() error {
if err := ejvw.ensureElementValue(mode(0), "WriteUndefined"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("undefined", "true", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
switch ejvw.stack[ejvw.frame].mode {
case mDocument, mTopLevel, mCodeWithScope:
ejvw.buf = append(ejvw.buf, []byte(fmt.Sprintf(`"%s":`, key))...)
ejvw.push(mElement)
default:
return nil, ejvw.invalidTransitionErr(mElement, "WriteDocumentElement", []mode{mDocument, mTopLevel, mCodeWithScope})
}
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDocumentEnd() error {
switch ejvw.stack[ejvw.frame].mode {
case mDocument, mTopLevel, mCodeWithScope:
default:
return fmt.Errorf("incorrect mode to end document: %s", ejvw.stack[ejvw.frame].mode)
}
// close the document
if ejvw.buf[len(ejvw.buf)-1] == ',' {
ejvw.buf[len(ejvw.buf)-1] = '}'
} else {
ejvw.buf = append(ejvw.buf, '}')
}
switch ejvw.stack[ejvw.frame].mode {
case mCodeWithScope:
ejvw.buf = append(ejvw.buf, '}')
fallthrough
case mDocument:
ejvw.buf = append(ejvw.buf, ',')
case mTopLevel:
if ejvw.w != nil {
if _, err := ejvw.w.Write(ejvw.buf); err != nil {
return err
}
ejvw.buf = ejvw.buf[:0]
}
}
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteArrayElement() (ValueWriter, error) {
switch ejvw.stack[ejvw.frame].mode {
case mArray:
ejvw.push(mValue)
default:
return nil, ejvw.invalidTransitionErr(mValue, "WriteArrayElement", []mode{mArray})
}
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteArrayEnd() error {
switch ejvw.stack[ejvw.frame].mode {
case mArray:
// close the array
if ejvw.buf[len(ejvw.buf)-1] == ',' {
ejvw.buf[len(ejvw.buf)-1] = ']'
} else {
ejvw.buf = append(ejvw.buf, ']')
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
default:
return fmt.Errorf("incorrect mode to end array: %s", ejvw.stack[ejvw.frame].mode)
}
return nil
}
func formatDouble(f float64) string {
var s string
if math.IsInf(f, 1) {
s = "Infinity"
} else if math.IsInf(f, -1) {
s = "-Infinity"
} else if math.IsNaN(f) {
s = "NaN"
} else {
// Print exactly one decimalType place for integers; otherwise, print as many are necessary to
// perfectly represent it.
s = strconv.FormatFloat(f, 'G', -1, 64)
if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') {
s += ".0"
}
}
return s
}
var hexChars = "0123456789abcdef"
func writeStringWithEscapes(s string, buf *bytes.Buffer, escapeHTML bool) {
buf.WriteByte('"')
start := 0
for i := 0; i < len(s); {
if b := s[i]; b < utf8.RuneSelf {
if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
i++
continue
}
if start < i {
buf.WriteString(s[start:i])
}
switch b {
case '\\', '"':
buf.WriteByte('\\')
buf.WriteByte(b)
case '\n':
buf.WriteByte('\\')
buf.WriteByte('n')
case '\r':
buf.WriteByte('\\')
buf.WriteByte('r')
case '\t':
buf.WriteByte('\\')
buf.WriteByte('t')
case '\b':
buf.WriteByte('\\')
buf.WriteByte('b')
case '\f':
buf.WriteByte('\\')
buf.WriteByte('f')
default:
// This encodes bytes < 0x20 except for \t, \n and \r.
// If escapeHTML is set, it also escapes <, >, and &
// because they can lead to security holes when
// user-controlled strings are rendered into JSON
// and served to some browsers.
buf.WriteString(`\u00`)
buf.WriteByte(hexChars[b>>4])
buf.WriteByte(hexChars[b&0xF])
}
i++
start = i
continue
}
c, size := utf8.DecodeRuneInString(s[i:])
if c == utf8.RuneError && size == 1 {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\ufffd`)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
if c == '\u2028' || c == '\u2029' {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\u202`)
buf.WriteByte(hexChars[c&0xF])
i += size
start = i
continue
}
i += size
}
if start < len(s) {
buf.WriteString(s[start:])
}
buf.WriteByte('"')
}
type sortableString []rune
func (ss sortableString) Len() int {
return len(ss)
}
func (ss sortableString) Less(i, j int) bool {
return ss[i] < ss[j]
}
func (ss sortableString) Swap(i, j int) {
oldI := ss[i]
ss[i] = ss[j]
ss[j] = oldI
}
func sortStringAlphebeticAscending(s string) string {
ss := sortableString([]rune(s))
sort.Sort(ss)
return string([]rune(ss))
}

View File

@@ -0,0 +1,439 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"strconv"
"strings"
"unicode"
)
type jsonTokenType byte
const (
jttBeginObject jsonTokenType = iota
jttEndObject
jttBeginArray
jttEndArray
jttColon
jttComma
jttInt32
jttInt64
jttDouble
jttString
jttBool
jttNull
jttEOF
)
type jsonToken struct {
t jsonTokenType
v interface{}
p int
}
type jsonScanner struct {
r io.Reader
buf []byte
pos int
lastReadErr error
}
// nextToken returns the next JSON token if one exists. A token is a character
// of the JSON grammar, a number, a string, or a literal.
func (js *jsonScanner) nextToken() (*jsonToken, error) {
c, err := js.readNextByte()
// keep reading until a non-space is encountered (break on read error or EOF)
for isWhiteSpace(c) && err == nil {
c, err = js.readNextByte()
}
if err == io.EOF {
return &jsonToken{t: jttEOF}, nil
} else if err != nil {
return nil, err
}
// switch on the character
switch c {
case '{':
return &jsonToken{t: jttBeginObject, v: byte('{'), p: js.pos - 1}, nil
case '}':
return &jsonToken{t: jttEndObject, v: byte('}'), p: js.pos - 1}, nil
case '[':
return &jsonToken{t: jttBeginArray, v: byte('['), p: js.pos - 1}, nil
case ']':
return &jsonToken{t: jttEndArray, v: byte(']'), p: js.pos - 1}, nil
case ':':
return &jsonToken{t: jttColon, v: byte(':'), p: js.pos - 1}, nil
case ',':
return &jsonToken{t: jttComma, v: byte(','), p: js.pos - 1}, nil
case '"': // RFC-8259 only allows for double quotes (") not single (')
return js.scanString()
default:
// check if it's a number
if c == '-' || isDigit(c) {
return js.scanNumber(c)
} else if c == 't' || c == 'f' || c == 'n' {
// maybe a literal
return js.scanLiteral(c)
} else {
return nil, fmt.Errorf("invalid JSON input. Position: %d. Character: %c", js.pos-1, c)
}
}
}
// readNextByte attempts to read the next byte from the buffer. If the buffer
// has been exhausted, this function calls readIntoBuf, thus refilling the
// buffer and resetting the read position to 0
func (js *jsonScanner) readNextByte() (byte, error) {
if js.pos >= len(js.buf) {
err := js.readIntoBuf()
if err != nil {
return 0, err
}
}
b := js.buf[js.pos]
js.pos++
return b, nil
}
// readNNextBytes reads n bytes into dst, starting at offset
func (js *jsonScanner) readNNextBytes(dst []byte, n, offset int) error {
var err error
for i := 0; i < n; i++ {
dst[i+offset], err = js.readNextByte()
if err != nil {
return err
}
}
return nil
}
// readIntoBuf reads up to 512 bytes from the scanner's io.Reader into the buffer
func (js *jsonScanner) readIntoBuf() error {
if js.lastReadErr != nil {
js.buf = js.buf[:0]
js.pos = 0
return js.lastReadErr
}
if cap(js.buf) == 0 {
js.buf = make([]byte, 0, 512)
}
n, err := js.r.Read(js.buf[:cap(js.buf)])
if err != nil {
js.lastReadErr = err
if n > 0 {
err = nil
}
}
js.buf = js.buf[:n]
js.pos = 0
return err
}
func isWhiteSpace(c byte) bool {
return c == ' ' || c == '\t' || c == '\r' || c == '\n'
}
func isDigit(c byte) bool {
return unicode.IsDigit(rune(c))
}
func isValueTerminator(c byte) bool {
return c == ',' || c == '}' || c == ']' || isWhiteSpace(c)
}
// scanString reads from an opening '"' to a closing '"' and handles escaped characters
func (js *jsonScanner) scanString() (*jsonToken, error) {
var b bytes.Buffer
var c byte
var err error
p := js.pos - 1
for {
c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
switch c {
case '\\':
c, err = js.readNextByte()
switch c {
case '"', '\\', '/', '\'':
b.WriteByte(c)
case 'b':
b.WriteByte('\b')
case 'f':
b.WriteByte('\f')
case 'n':
b.WriteByte('\n')
case 'r':
b.WriteByte('\r')
case 't':
b.WriteByte('\t')
case 'u':
us := make([]byte, 4)
err = js.readNNextBytes(us, 4, 0)
if err != nil {
return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us)
}
s := fmt.Sprintf(`\u%s`, us)
s, err = strconv.Unquote(strings.Replace(strconv.Quote(s), `\\u`, `\u`, 1))
if err != nil {
return nil, err
}
b.WriteString(s)
default:
return nil, fmt.Errorf("invalid escape sequence in JSON string '\\%c'", c)
}
case '"':
return &jsonToken{t: jttString, v: b.String(), p: p}, nil
default:
b.WriteByte(c)
}
}
}
// scanLiteral reads an unquoted sequence of characters and determines if it is one of
// three valid JSON literals (true, false, null); if so, it returns the appropriate
// jsonToken; otherwise, it returns an error
func (js *jsonScanner) scanLiteral(first byte) (*jsonToken, error) {
p := js.pos - 1
lit := make([]byte, 4)
lit[0] = first
err := js.readNNextBytes(lit, 3, 1)
if err != nil {
return nil, err
}
c5, err := js.readNextByte()
if bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || err == io.EOF) {
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttBool, v: true, p: p}, nil
} else if bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || err == io.EOF) {
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttNull, v: nil, p: p}, nil
} else if bytes.Equal([]byte("fals"), lit) {
if c5 == 'e' {
c5, err = js.readNextByte()
if isValueTerminator(c5) || err == io.EOF {
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttBool, v: false, p: p}, nil
}
}
}
return nil, fmt.Errorf("invalid JSON literal. Position: %d, literal: %s", p, lit)
}
type numberScanState byte
const (
nssSawLeadingMinus numberScanState = iota
nssSawLeadingZero
nssSawIntegerDigits
nssSawDecimalPoint
nssSawFractionDigits
nssSawExponentLetter
nssSawExponentSign
nssSawExponentDigits
nssDone
nssInvalid
)
// scanNumber reads a JSON number (according to RFC-8259)
func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) {
var b bytes.Buffer
var s numberScanState
var c byte
var err error
t := jttInt64 // assume it's an int64 until the type can be determined
start := js.pos - 1
b.WriteByte(first)
switch first {
case '-':
s = nssSawLeadingMinus
case '0':
s = nssSawLeadingZero
default:
s = nssSawIntegerDigits
}
for {
c, err = js.readNextByte()
if err != nil && err != io.EOF {
return nil, err
}
switch s {
case nssSawLeadingMinus:
switch c {
case '0':
s = nssSawLeadingZero
b.WriteByte(c)
default:
if isDigit(c) {
s = nssSawIntegerDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawLeadingZero:
switch c {
case '.':
s = nssSawDecimalPoint
b.WriteByte(c)
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || err == io.EOF {
s = nssDone
} else {
s = nssInvalid
}
}
case nssSawIntegerDigits:
switch c {
case '.':
s = nssSawDecimalPoint
b.WriteByte(c)
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || err == io.EOF {
s = nssDone
} else if isDigit(c) {
s = nssSawIntegerDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawDecimalPoint:
t = jttDouble
if isDigit(c) {
s = nssSawFractionDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
case nssSawFractionDigits:
switch c {
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || err == io.EOF {
s = nssDone
} else if isDigit(c) {
s = nssSawFractionDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawExponentLetter:
t = jttDouble
switch c {
case '+', '-':
s = nssSawExponentSign
b.WriteByte(c)
default:
if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawExponentSign:
if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
case nssSawExponentDigits:
switch c {
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || err == io.EOF {
s = nssDone
} else if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
}
switch s {
case nssInvalid:
return nil, fmt.Errorf("invalid JSON number. Position: %d", start)
case nssDone:
js.pos = int(math.Max(0, float64(js.pos-1)))
if t != jttDouble {
v, err := strconv.ParseInt(b.String(), 10, 64)
if err == nil {
if v < math.MinInt32 || v > math.MaxInt32 {
return &jsonToken{t: jttInt64, v: v, p: start}, nil
}
return &jsonToken{t: jttInt32, v: int32(v), p: start}, nil
}
}
v, err := strconv.ParseFloat(b.String(), 64)
if err != nil {
return nil, err
}
return &jsonToken{t: jttDouble, v: v, p: start}, nil
}
}
}

108
vendor/go.mongodb.org/mongo-driver/bson/bsonrw/mode.go generated vendored Normal file
View File

@@ -0,0 +1,108 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"fmt"
)
type mode int
const (
_ mode = iota
mTopLevel
mDocument
mArray
mValue
mElement
mCodeWithScope
mSpacer
)
func (m mode) String() string {
var str string
switch m {
case mTopLevel:
str = "TopLevel"
case mDocument:
str = "DocumentMode"
case mArray:
str = "ArrayMode"
case mValue:
str = "ValueMode"
case mElement:
str = "ElementMode"
case mCodeWithScope:
str = "CodeWithScopeMode"
case mSpacer:
str = "CodeWithScopeSpacerFrame"
default:
str = "UnknownMode"
}
return str
}
func (m mode) TypeString() string {
var str string
switch m {
case mTopLevel:
str = "TopLevel"
case mDocument:
str = "Document"
case mArray:
str = "Array"
case mValue:
str = "Value"
case mElement:
str = "Element"
case mCodeWithScope:
str = "CodeWithScope"
case mSpacer:
str = "CodeWithScopeSpacer"
default:
str = "Unknown"
}
return str
}
// TransitionError is an error returned when an invalid progressing a
// ValueReader or ValueWriter state machine occurs.
// If read is false, the error is for writing
type TransitionError struct {
name string
parent mode
current mode
destination mode
modes []mode
action string
}
func (te TransitionError) Error() string {
errString := fmt.Sprintf("%s can only %s", te.name, te.action)
if te.destination != mode(0) {
errString = fmt.Sprintf("%s a %s", errString, te.destination.TypeString())
}
errString = fmt.Sprintf("%s while positioned on a", errString)
for ind, m := range te.modes {
if ind != 0 && len(te.modes) > 2 {
errString = fmt.Sprintf("%s,", errString)
}
if ind == len(te.modes)-1 && len(te.modes) > 1 {
errString = fmt.Sprintf("%s or", errString)
}
errString = fmt.Sprintf("%s %s", errString, m.TypeString())
}
errString = fmt.Sprintf("%s but is positioned on a %s", errString, te.current.TypeString())
if te.parent != mode(0) {
errString = fmt.Sprintf("%s with parent %s", errString, te.parent.TypeString())
}
return errString
}

View File

@@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ArrayReader is implemented by types that allow reading values from a BSON
// array.
type ArrayReader interface {
ReadValue() (ValueReader, error)
}
// DocumentReader is implemented by types that allow reading elements from a
// BSON document.
type DocumentReader interface {
ReadElement() (string, ValueReader, error)
}
// ValueReader is a generic interface used to read values from BSON. This type
// is implemented by several types with different underlying representations of
// BSON, such as a bson.Document, raw BSON bytes, or extended JSON.
type ValueReader interface {
Type() bsontype.Type
Skip() error
ReadArray() (ArrayReader, error)
ReadBinary() (b []byte, btype byte, err error)
ReadBoolean() (bool, error)
ReadDocument() (DocumentReader, error)
ReadCodeWithScope() (code string, dr DocumentReader, err error)
ReadDBPointer() (ns string, oid primitive.ObjectID, err error)
ReadDateTime() (int64, error)
ReadDecimal128() (primitive.Decimal128, error)
ReadDouble() (float64, error)
ReadInt32() (int32, error)
ReadInt64() (int64, error)
ReadJavascript() (code string, err error)
ReadMaxKey() error
ReadMinKey() error
ReadNull() error
ReadObjectID() (primitive.ObjectID, error)
ReadRegex() (pattern, options string, err error)
ReadString() (string, error)
ReadSymbol() (symbol string, err error)
ReadTimestamp() (t, i uint32, err error)
ReadUndefined() error
}
// BytesReader is a generic interface used to read BSON bytes from a
// ValueReader. This imterface is meant to be a superset of ValueReader, so that
// types that implement ValueReader may also implement this interface.
//
// The bytes of the value will be appended to dst.
type BytesReader interface {
ReadValueBytes(dst []byte) (bsontype.Type, []byte, error)
}

View File

@@ -0,0 +1,882 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"sync"
"unicode"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
var _ ValueReader = (*valueReader)(nil)
var vrPool = sync.Pool{
New: func() interface{} {
return new(valueReader)
},
}
// BSONValueReaderPool is a pool for ValueReaders that read BSON.
type BSONValueReaderPool struct {
pool sync.Pool
}
// NewBSONValueReaderPool instantiates a new BSONValueReaderPool.
func NewBSONValueReaderPool() *BSONValueReaderPool {
return &BSONValueReaderPool{
pool: sync.Pool{
New: func() interface{} {
return new(valueReader)
},
},
}
}
// Get retrieves a ValueReader from the pool and uses src as the underlying BSON.
func (bvrp *BSONValueReaderPool) Get(src []byte) ValueReader {
vr := bvrp.pool.Get().(*valueReader)
vr.reset(src)
return vr
}
// Put inserts a ValueReader into the pool. If the ValueReader is not a BSON ValueReader nothing
// is inserted into the pool and ok will be false.
func (bvrp *BSONValueReaderPool) Put(vr ValueReader) (ok bool) {
bvr, ok := vr.(*valueReader)
if !ok {
return false
}
bvr.reset(nil)
bvrp.pool.Put(bvr)
return true
}
// ErrEOA is the error returned when the end of a BSON array has been reached.
var ErrEOA = errors.New("end of array")
// ErrEOD is the error returned when the end of a BSON document has been reached.
var ErrEOD = errors.New("end of document")
type vrState struct {
mode mode
vType bsontype.Type
end int64
}
// valueReader is for reading BSON values.
type valueReader struct {
offset int64
d []byte
stack []vrState
frame int64
}
// NewBSONDocumentReader returns a ValueReader using b for the underlying BSON
// representation. Parameter b must be a BSON Document.
//
// TODO(skriptble): There's a lack of symmetry between the reader and writer, since the reader takes
// a []byte while the writer takes an io.Writer. We should have two versions of each, one that takes
// a []byte and one that takes an io.Reader or io.Writer. The []byte version will need to return a
// thing that can return the finished []byte since it might be reallocated when appended to.
func NewBSONDocumentReader(b []byte) ValueReader {
return newValueReader(b)
}
// NewBSONValueReader returns a ValueReader that starts in the Value mode instead of in top
// level document mode. This enables the creation of a ValueReader for a single BSON value.
func NewBSONValueReader(t bsontype.Type, val []byte) ValueReader {
stack := make([]vrState, 1, 5)
stack[0] = vrState{
mode: mValue,
vType: t,
}
return &valueReader{
d: val,
stack: stack,
}
}
func newValueReader(b []byte) *valueReader {
stack := make([]vrState, 1, 5)
stack[0] = vrState{
mode: mTopLevel,
}
return &valueReader{
d: b,
stack: stack,
}
}
func (vr *valueReader) reset(b []byte) {
if vr.stack == nil {
vr.stack = make([]vrState, 1, 5)
}
vr.stack = vr.stack[:1]
vr.stack[0] = vrState{mode: mTopLevel}
vr.d = b
vr.offset = 0
vr.frame = 0
}
func (vr *valueReader) advanceFrame() {
if vr.frame+1 >= int64(len(vr.stack)) { // We need to grow the stack
length := len(vr.stack)
if length+1 >= cap(vr.stack) {
// double it
buf := make([]vrState, 2*cap(vr.stack)+1)
copy(buf, vr.stack)
vr.stack = buf
}
vr.stack = vr.stack[:length+1]
}
vr.frame++
// Clean the stack
vr.stack[vr.frame].mode = 0
vr.stack[vr.frame].vType = 0
vr.stack[vr.frame].end = 0
}
func (vr *valueReader) pushDocument() error {
vr.advanceFrame()
vr.stack[vr.frame].mode = mDocument
size, err := vr.readLength()
if err != nil {
return err
}
vr.stack[vr.frame].end = int64(size) + vr.offset - 4
return nil
}
func (vr *valueReader) pushArray() error {
vr.advanceFrame()
vr.stack[vr.frame].mode = mArray
size, err := vr.readLength()
if err != nil {
return err
}
vr.stack[vr.frame].end = int64(size) + vr.offset - 4
return nil
}
func (vr *valueReader) pushElement(t bsontype.Type) {
vr.advanceFrame()
vr.stack[vr.frame].mode = mElement
vr.stack[vr.frame].vType = t
}
func (vr *valueReader) pushValue(t bsontype.Type) {
vr.advanceFrame()
vr.stack[vr.frame].mode = mValue
vr.stack[vr.frame].vType = t
}
func (vr *valueReader) pushCodeWithScope() (int64, error) {
vr.advanceFrame()
vr.stack[vr.frame].mode = mCodeWithScope
size, err := vr.readLength()
if err != nil {
return 0, err
}
vr.stack[vr.frame].end = int64(size) + vr.offset - 4
return int64(size), nil
}
func (vr *valueReader) pop() {
switch vr.stack[vr.frame].mode {
case mElement, mValue:
vr.frame--
case mDocument, mArray, mCodeWithScope:
vr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc...
}
}
func (vr *valueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: vr.stack[vr.frame].mode,
destination: destination,
modes: modes,
action: "read",
}
if vr.frame != 0 {
te.parent = vr.stack[vr.frame-1].mode
}
return te
}
func (vr *valueReader) typeError(t bsontype.Type) error {
return fmt.Errorf("positioned on %s, but attempted to read %s", vr.stack[vr.frame].vType, t)
}
func (vr *valueReader) invalidDocumentLengthError() error {
return fmt.Errorf("document is invalid, end byte is at %d, but null byte found at %d", vr.stack[vr.frame].end, vr.offset)
}
func (vr *valueReader) ensureElementValue(t bsontype.Type, destination mode, callerName string) error {
switch vr.stack[vr.frame].mode {
case mElement, mValue:
if vr.stack[vr.frame].vType != t {
return vr.typeError(t)
}
default:
return vr.invalidTransitionErr(destination, callerName, []mode{mElement, mValue})
}
return nil
}
func (vr *valueReader) Type() bsontype.Type {
return vr.stack[vr.frame].vType
}
func (vr *valueReader) nextElementLength() (int32, error) {
var length int32
var err error
switch vr.stack[vr.frame].vType {
case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope:
length, err = vr.peekLength()
case bsontype.Binary:
length, err = vr.peekLength()
length += 4 + 1 // binary length + subtype byte
case bsontype.Boolean:
length = 1
case bsontype.DBPointer:
length, err = vr.peekLength()
length += 4 + 12 // string length + ObjectID length
case bsontype.DateTime, bsontype.Double, bsontype.Int64, bsontype.Timestamp:
length = 8
case bsontype.Decimal128:
length = 16
case bsontype.Int32:
length = 4
case bsontype.JavaScript, bsontype.String, bsontype.Symbol:
length, err = vr.peekLength()
length += 4
case bsontype.MaxKey, bsontype.MinKey, bsontype.Null, bsontype.Undefined:
length = 0
case bsontype.ObjectID:
length = 12
case bsontype.Regex:
regex := bytes.IndexByte(vr.d[vr.offset:], 0x00)
if regex < 0 {
err = io.EOF
break
}
pattern := bytes.IndexByte(vr.d[vr.offset+int64(regex)+1:], 0x00)
if pattern < 0 {
err = io.EOF
break
}
length = int32(int64(regex) + 1 + int64(pattern) + 1)
default:
return 0, fmt.Errorf("attempted to read bytes of unknown BSON type %v", vr.stack[vr.frame].vType)
}
return length, err
}
func (vr *valueReader) ReadValueBytes(dst []byte) (bsontype.Type, []byte, error) {
switch vr.stack[vr.frame].mode {
case mTopLevel:
length, err := vr.peekLength()
if err != nil {
return bsontype.Type(0), nil, err
}
dst, err = vr.appendBytes(dst, length)
if err != nil {
return bsontype.Type(0), nil, err
}
return bsontype.Type(0), dst, nil
case mElement, mValue:
length, err := vr.nextElementLength()
if err != nil {
return bsontype.Type(0), dst, err
}
dst, err = vr.appendBytes(dst, length)
t := vr.stack[vr.frame].vType
vr.pop()
return t, dst, err
default:
return bsontype.Type(0), nil, vr.invalidTransitionErr(0, "ReadValueBytes", []mode{mElement, mValue})
}
}
func (vr *valueReader) Skip() error {
switch vr.stack[vr.frame].mode {
case mElement, mValue:
default:
return vr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
}
length, err := vr.nextElementLength()
if err != nil {
return err
}
err = vr.skipBytes(length)
vr.pop()
return err
}
func (vr *valueReader) ReadArray() (ArrayReader, error) {
if err := vr.ensureElementValue(bsontype.Array, mArray, "ReadArray"); err != nil {
return nil, err
}
err := vr.pushArray()
if err != nil {
return nil, err
}
return vr, nil
}
func (vr *valueReader) ReadBinary() (b []byte, btype byte, err error) {
if err := vr.ensureElementValue(bsontype.Binary, 0, "ReadBinary"); err != nil {
return nil, 0, err
}
length, err := vr.readLength()
if err != nil {
return nil, 0, err
}
btype, err = vr.readByte()
if err != nil {
return nil, 0, err
}
if btype == 0x02 {
length, err = vr.readLength()
if err != nil {
return nil, 0, err
}
}
b, err = vr.readBytes(length)
if err != nil {
return nil, 0, err
}
vr.pop()
return b, btype, nil
}
func (vr *valueReader) ReadBoolean() (bool, error) {
if err := vr.ensureElementValue(bsontype.Boolean, 0, "ReadBoolean"); err != nil {
return false, err
}
b, err := vr.readByte()
if err != nil {
return false, err
}
if b > 1 {
return false, fmt.Errorf("invalid byte for boolean, %b", b)
}
vr.pop()
return b == 1, nil
}
func (vr *valueReader) ReadDocument() (DocumentReader, error) {
switch vr.stack[vr.frame].mode {
case mTopLevel:
// read size
size, err := vr.readLength()
if err != nil {
return nil, err
}
if int(size) != len(vr.d) {
return nil, fmt.Errorf("invalid document length")
}
vr.stack[vr.frame].end = int64(size) + vr.offset - 4
return vr, nil
case mElement, mValue:
if vr.stack[vr.frame].vType != bsontype.EmbeddedDocument {
return nil, vr.typeError(bsontype.EmbeddedDocument)
}
default:
return nil, vr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
}
err := vr.pushDocument()
if err != nil {
return nil, err
}
return vr, nil
}
func (vr *valueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
if err := vr.ensureElementValue(bsontype.CodeWithScope, 0, "ReadCodeWithScope"); err != nil {
return "", nil, err
}
totalLength, err := vr.readLength()
if err != nil {
return "", nil, err
}
strLength, err := vr.readLength()
if err != nil {
return "", nil, err
}
strBytes, err := vr.readBytes(strLength)
if err != nil {
return "", nil, err
}
code = string(strBytes[:len(strBytes)-1])
size, err := vr.pushCodeWithScope()
if err != nil {
return "", nil, err
}
// The total length should equal:
// 4 (total length) + strLength + 4 (the length of str itself) + (document length)
componentsLength := int64(4+strLength+4) + size
if int64(totalLength) != componentsLength {
return "", nil, fmt.Errorf(
"length of CodeWithScope does not match lengths of components; total: %d; components: %d",
totalLength, componentsLength,
)
}
return code, vr, nil
}
func (vr *valueReader) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) {
if err := vr.ensureElementValue(bsontype.DBPointer, 0, "ReadDBPointer"); err != nil {
return "", oid, err
}
ns, err = vr.readString()
if err != nil {
return "", oid, err
}
oidbytes, err := vr.readBytes(12)
if err != nil {
return "", oid, err
}
copy(oid[:], oidbytes)
vr.pop()
return ns, oid, nil
}
func (vr *valueReader) ReadDateTime() (int64, error) {
if err := vr.ensureElementValue(bsontype.DateTime, 0, "ReadDateTime"); err != nil {
return 0, err
}
i, err := vr.readi64()
if err != nil {
return 0, err
}
vr.pop()
return i, nil
}
func (vr *valueReader) ReadDecimal128() (primitive.Decimal128, error) {
if err := vr.ensureElementValue(bsontype.Decimal128, 0, "ReadDecimal128"); err != nil {
return primitive.Decimal128{}, err
}
b, err := vr.readBytes(16)
if err != nil {
return primitive.Decimal128{}, err
}
l := binary.LittleEndian.Uint64(b[0:8])
h := binary.LittleEndian.Uint64(b[8:16])
vr.pop()
return primitive.NewDecimal128(h, l), nil
}
func (vr *valueReader) ReadDouble() (float64, error) {
if err := vr.ensureElementValue(bsontype.Double, 0, "ReadDouble"); err != nil {
return 0, err
}
u, err := vr.readu64()
if err != nil {
return 0, err
}
vr.pop()
return math.Float64frombits(u), nil
}
func (vr *valueReader) ReadInt32() (int32, error) {
if err := vr.ensureElementValue(bsontype.Int32, 0, "ReadInt32"); err != nil {
return 0, err
}
vr.pop()
return vr.readi32()
}
func (vr *valueReader) ReadInt64() (int64, error) {
if err := vr.ensureElementValue(bsontype.Int64, 0, "ReadInt64"); err != nil {
return 0, err
}
vr.pop()
return vr.readi64()
}
func (vr *valueReader) ReadJavascript() (code string, err error) {
if err := vr.ensureElementValue(bsontype.JavaScript, 0, "ReadJavascript"); err != nil {
return "", err
}
vr.pop()
return vr.readString()
}
func (vr *valueReader) ReadMaxKey() error {
if err := vr.ensureElementValue(bsontype.MaxKey, 0, "ReadMaxKey"); err != nil {
return err
}
vr.pop()
return nil
}
func (vr *valueReader) ReadMinKey() error {
if err := vr.ensureElementValue(bsontype.MinKey, 0, "ReadMinKey"); err != nil {
return err
}
vr.pop()
return nil
}
func (vr *valueReader) ReadNull() error {
if err := vr.ensureElementValue(bsontype.Null, 0, "ReadNull"); err != nil {
return err
}
vr.pop()
return nil
}
func (vr *valueReader) ReadObjectID() (primitive.ObjectID, error) {
if err := vr.ensureElementValue(bsontype.ObjectID, 0, "ReadObjectID"); err != nil {
return primitive.ObjectID{}, err
}
oidbytes, err := vr.readBytes(12)
if err != nil {
return primitive.ObjectID{}, err
}
var oid primitive.ObjectID
copy(oid[:], oidbytes)
vr.pop()
return oid, nil
}
func (vr *valueReader) ReadRegex() (string, string, error) {
if err := vr.ensureElementValue(bsontype.Regex, 0, "ReadRegex"); err != nil {
return "", "", err
}
pattern, err := vr.readCString()
if err != nil {
return "", "", err
}
options, err := vr.readCString()
if err != nil {
return "", "", err
}
vr.pop()
return pattern, options, nil
}
func (vr *valueReader) ReadString() (string, error) {
if err := vr.ensureElementValue(bsontype.String, 0, "ReadString"); err != nil {
return "", err
}
vr.pop()
return vr.readString()
}
func (vr *valueReader) ReadSymbol() (symbol string, err error) {
if err := vr.ensureElementValue(bsontype.Symbol, 0, "ReadSymbol"); err != nil {
return "", err
}
vr.pop()
return vr.readString()
}
func (vr *valueReader) ReadTimestamp() (t uint32, i uint32, err error) {
if err := vr.ensureElementValue(bsontype.Timestamp, 0, "ReadTimestamp"); err != nil {
return 0, 0, err
}
i, err = vr.readu32()
if err != nil {
return 0, 0, err
}
t, err = vr.readu32()
if err != nil {
return 0, 0, err
}
vr.pop()
return t, i, nil
}
func (vr *valueReader) ReadUndefined() error {
if err := vr.ensureElementValue(bsontype.Undefined, 0, "ReadUndefined"); err != nil {
return err
}
vr.pop()
return nil
}
func (vr *valueReader) ReadElement() (string, ValueReader, error) {
switch vr.stack[vr.frame].mode {
case mTopLevel, mDocument, mCodeWithScope:
default:
return "", nil, vr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
}
t, err := vr.readByte()
if err != nil {
return "", nil, err
}
if t == 0 {
if vr.offset != vr.stack[vr.frame].end {
return "", nil, vr.invalidDocumentLengthError()
}
vr.pop()
return "", nil, ErrEOD
}
name, err := vr.readCString()
if err != nil {
return "", nil, err
}
vr.pushElement(bsontype.Type(t))
return name, vr, nil
}
func (vr *valueReader) ReadValue() (ValueReader, error) {
switch vr.stack[vr.frame].mode {
case mArray:
default:
return nil, vr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
}
t, err := vr.readByte()
if err != nil {
return nil, err
}
if t == 0 {
if vr.offset != vr.stack[vr.frame].end {
return nil, vr.invalidDocumentLengthError()
}
vr.pop()
return nil, ErrEOA
}
_, err = vr.readCString()
if err != nil {
return nil, err
}
vr.pushValue(bsontype.Type(t))
return vr, nil
}
func (vr *valueReader) readBytes(length int32) ([]byte, error) {
if length < 0 {
return nil, fmt.Errorf("invalid length: %d", length)
}
if vr.offset+int64(length) > int64(len(vr.d)) {
return nil, io.EOF
}
start := vr.offset
vr.offset += int64(length)
return vr.d[start : start+int64(length)], nil
}
func (vr *valueReader) appendBytes(dst []byte, length int32) ([]byte, error) {
if vr.offset+int64(length) > int64(len(vr.d)) {
return nil, io.EOF
}
start := vr.offset
vr.offset += int64(length)
return append(dst, vr.d[start:start+int64(length)]...), nil
}
func (vr *valueReader) skipBytes(length int32) error {
if vr.offset+int64(length) > int64(len(vr.d)) {
return io.EOF
}
vr.offset += int64(length)
return nil
}
func (vr *valueReader) readByte() (byte, error) {
if vr.offset+1 > int64(len(vr.d)) {
return 0x0, io.EOF
}
vr.offset++
return vr.d[vr.offset-1], nil
}
func (vr *valueReader) readCString() (string, error) {
idx := bytes.IndexByte(vr.d[vr.offset:], 0x00)
if idx < 0 {
return "", io.EOF
}
start := vr.offset
// idx does not include the null byte
vr.offset += int64(idx) + 1
return string(vr.d[start : start+int64(idx)]), nil
}
func (vr *valueReader) skipCString() error {
idx := bytes.IndexByte(vr.d[vr.offset:], 0x00)
if idx < 0 {
return io.EOF
}
// idx does not include the null byte
vr.offset += int64(idx) + 1
return nil
}
func (vr *valueReader) readString() (string, error) {
length, err := vr.readLength()
if err != nil {
return "", err
}
if int64(length)+vr.offset > int64(len(vr.d)) {
return "", io.EOF
}
if length <= 0 {
return "", fmt.Errorf("invalid string length: %d", length)
}
if vr.d[vr.offset+int64(length)-1] != 0x00 {
return "", fmt.Errorf("string does not end with null byte, but with %v", vr.d[vr.offset+int64(length)-1])
}
start := vr.offset
vr.offset += int64(length)
if length == 2 {
asciiByte := vr.d[start]
if asciiByte > unicode.MaxASCII {
return "", fmt.Errorf("invalid ascii byte")
}
}
return string(vr.d[start : start+int64(length)-1]), nil
}
func (vr *valueReader) peekLength() (int32, error) {
if vr.offset+4 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil
}
func (vr *valueReader) readLength() (int32, error) { return vr.readi32() }
func (vr *valueReader) readi32() (int32, error) {
if vr.offset+4 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
vr.offset += 4
return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil
}
func (vr *valueReader) readu32() (uint32, error) {
if vr.offset+4 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
vr.offset += 4
return (uint32(vr.d[idx]) | uint32(vr.d[idx+1])<<8 | uint32(vr.d[idx+2])<<16 | uint32(vr.d[idx+3])<<24), nil
}
func (vr *valueReader) readi64() (int64, error) {
if vr.offset+8 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
vr.offset += 8
return int64(vr.d[idx]) | int64(vr.d[idx+1])<<8 | int64(vr.d[idx+2])<<16 | int64(vr.d[idx+3])<<24 |
int64(vr.d[idx+4])<<32 | int64(vr.d[idx+5])<<40 | int64(vr.d[idx+6])<<48 | int64(vr.d[idx+7])<<56, nil
}
func (vr *valueReader) readu64() (uint64, error) {
if vr.offset+8 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
vr.offset += 8
return uint64(vr.d[idx]) | uint64(vr.d[idx+1])<<8 | uint64(vr.d[idx+2])<<16 | uint64(vr.d[idx+3])<<24 |
uint64(vr.d[idx+4])<<32 | uint64(vr.d[idx+5])<<40 | uint64(vr.d[idx+6])<<48 | uint64(vr.d[idx+7])<<56, nil
}

View File

@@ -0,0 +1,589 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"errors"
"fmt"
"io"
"math"
"strconv"
"sync"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
var _ ValueWriter = (*valueWriter)(nil)
var vwPool = sync.Pool{
New: func() interface{} {
return new(valueWriter)
},
}
// BSONValueWriterPool is a pool for BSON ValueWriters.
type BSONValueWriterPool struct {
pool sync.Pool
}
// NewBSONValueWriterPool creates a new pool for ValueWriter instances that write to BSON.
func NewBSONValueWriterPool() *BSONValueWriterPool {
return &BSONValueWriterPool{
pool: sync.Pool{
New: func() interface{} {
return new(valueWriter)
},
},
}
}
// Get retrieves a BSON ValueWriter from the pool and resets it to use w as the destination.
func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter {
vw := bvwp.pool.Get().(*valueWriter)
if writer, ok := w.(*SliceWriter); ok {
vw.reset(*writer)
vw.w = writer
return vw
}
vw.buf = vw.buf[:0]
vw.w = w
return vw
}
// Put inserts a ValueWriter into the pool. If the ValueWriter is not a BSON ValueWriter, nothing
// happens and ok will be false.
func (bvwp *BSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
bvw, ok := vw.(*valueWriter)
if !ok {
return false
}
if _, ok := bvw.w.(*SliceWriter); ok {
bvw.buf = nil
}
bvw.w = nil
bvwp.pool.Put(bvw)
return true
}
// This is here so that during testing we can change it and not require
// allocating a 4GB slice.
var maxSize = math.MaxInt32
var errNilWriter = errors.New("cannot create a ValueWriter from a nil io.Writer")
type errMaxDocumentSizeExceeded struct {
size int64
}
func (mdse errMaxDocumentSizeExceeded) Error() string {
return fmt.Sprintf("document size (%d) is larger than the max int32", mdse.size)
}
type vwMode int
const (
_ vwMode = iota
vwTopLevel
vwDocument
vwArray
vwValue
vwElement
vwCodeWithScope
)
func (vm vwMode) String() string {
var str string
switch vm {
case vwTopLevel:
str = "TopLevel"
case vwDocument:
str = "DocumentMode"
case vwArray:
str = "ArrayMode"
case vwValue:
str = "ValueMode"
case vwElement:
str = "ElementMode"
case vwCodeWithScope:
str = "CodeWithScopeMode"
default:
str = "UnknownMode"
}
return str
}
type vwState struct {
mode mode
key string
arrkey int
start int32
}
type valueWriter struct {
w io.Writer
buf []byte
stack []vwState
frame int64
}
func (vw *valueWriter) advanceFrame() {
if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack
length := len(vw.stack)
if length+1 >= cap(vw.stack) {
// double it
buf := make([]vwState, 2*cap(vw.stack)+1)
copy(buf, vw.stack)
vw.stack = buf
}
vw.stack = vw.stack[:length+1]
}
vw.frame++
}
func (vw *valueWriter) push(m mode) {
vw.advanceFrame()
// Clean the stack
vw.stack[vw.frame].mode = m
vw.stack[vw.frame].key = ""
vw.stack[vw.frame].arrkey = 0
vw.stack[vw.frame].start = 0
vw.stack[vw.frame].mode = m
switch m {
case mDocument, mArray, mCodeWithScope:
vw.reserveLength()
}
}
func (vw *valueWriter) reserveLength() {
vw.stack[vw.frame].start = int32(len(vw.buf))
vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00)
}
func (vw *valueWriter) pop() {
switch vw.stack[vw.frame].mode {
case mElement, mValue:
vw.frame--
case mDocument, mArray, mCodeWithScope:
vw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
}
}
// NewBSONValueWriter creates a ValueWriter that writes BSON to w.
//
// This ValueWriter will only write entire documents to the io.Writer and it
// will buffer the document as it is built.
func NewBSONValueWriter(w io.Writer) (ValueWriter, error) {
if w == nil {
return nil, errNilWriter
}
return newValueWriter(w), nil
}
func newValueWriter(w io.Writer) *valueWriter {
vw := new(valueWriter)
stack := make([]vwState, 1, 5)
stack[0] = vwState{mode: mTopLevel}
vw.w = w
vw.stack = stack
return vw
}
func newValueWriterFromSlice(buf []byte) *valueWriter {
vw := new(valueWriter)
stack := make([]vwState, 1, 5)
stack[0] = vwState{mode: mTopLevel}
vw.stack = stack
vw.buf = buf
return vw
}
func (vw *valueWriter) reset(buf []byte) {
if vw.stack == nil {
vw.stack = make([]vwState, 1, 5)
}
vw.stack = vw.stack[:1]
vw.stack[0] = vwState{mode: mTopLevel}
vw.buf = buf
vw.frame = 0
vw.w = nil
}
func (vw *valueWriter) invalidTransitionError(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: vw.stack[vw.frame].mode,
destination: destination,
modes: modes,
action: "write",
}
if vw.frame != 0 {
te.parent = vw.stack[vw.frame-1].mode
}
return te
}
func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error {
switch vw.stack[vw.frame].mode {
case mElement:
vw.buf = bsoncore.AppendHeader(vw.buf, t, vw.stack[vw.frame].key)
case mValue:
// TODO: Do this with a cache of the first 1000 or so array keys.
vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey))
default:
modes := []mode{mElement, mValue}
if addmodes != nil {
modes = append(modes, addmodes...)
}
return vw.invalidTransitionError(destination, callerName, modes)
}
return nil
}
func (vw *valueWriter) WriteValueBytes(t bsontype.Type, b []byte) error {
if err := vw.writeElementHeader(t, mode(0), "WriteValueBytes"); err != nil {
return err
}
vw.buf = append(vw.buf, b...)
vw.pop()
return nil
}
func (vw *valueWriter) WriteArray() (ArrayWriter, error) {
if err := vw.writeElementHeader(bsontype.Array, mArray, "WriteArray"); err != nil {
return nil, err
}
vw.push(mArray)
return vw, nil
}
func (vw *valueWriter) WriteBinary(b []byte) error {
return vw.WriteBinaryWithSubtype(b, 0x00)
}
func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
if err := vw.writeElementHeader(bsontype.Binary, mode(0), "WriteBinaryWithSubtype"); err != nil {
return err
}
vw.buf = bsoncore.AppendBinary(vw.buf, btype, b)
vw.pop()
return nil
}
func (vw *valueWriter) WriteBoolean(b bool) error {
if err := vw.writeElementHeader(bsontype.Boolean, mode(0), "WriteBoolean"); err != nil {
return err
}
vw.buf = bsoncore.AppendBoolean(vw.buf, b)
vw.pop()
return nil
}
func (vw *valueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
if err := vw.writeElementHeader(bsontype.CodeWithScope, mCodeWithScope, "WriteCodeWithScope"); err != nil {
return nil, err
}
// CodeWithScope is a different than other types because we need an extra
// frame on the stack. In the EndDocument code, we write the document
// length, pop, write the code with scope length, and pop. To simplify the
// pop code, we push a spacer frame that we'll always jump over.
vw.push(mCodeWithScope)
vw.buf = bsoncore.AppendString(vw.buf, code)
vw.push(mSpacer)
vw.push(mDocument)
return vw, nil
}
func (vw *valueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
if err := vw.writeElementHeader(bsontype.DBPointer, mode(0), "WriteDBPointer"); err != nil {
return err
}
vw.buf = bsoncore.AppendDBPointer(vw.buf, ns, oid)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDateTime(dt int64) error {
if err := vw.writeElementHeader(bsontype.DateTime, mode(0), "WriteDateTime"); err != nil {
return err
}
vw.buf = bsoncore.AppendDateTime(vw.buf, dt)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDecimal128(d128 primitive.Decimal128) error {
if err := vw.writeElementHeader(bsontype.Decimal128, mode(0), "WriteDecimal128"); err != nil {
return err
}
vw.buf = bsoncore.AppendDecimal128(vw.buf, d128)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDouble(f float64) error {
if err := vw.writeElementHeader(bsontype.Double, mode(0), "WriteDouble"); err != nil {
return err
}
vw.buf = bsoncore.AppendDouble(vw.buf, f)
vw.pop()
return nil
}
func (vw *valueWriter) WriteInt32(i32 int32) error {
if err := vw.writeElementHeader(bsontype.Int32, mode(0), "WriteInt32"); err != nil {
return err
}
vw.buf = bsoncore.AppendInt32(vw.buf, i32)
vw.pop()
return nil
}
func (vw *valueWriter) WriteInt64(i64 int64) error {
if err := vw.writeElementHeader(bsontype.Int64, mode(0), "WriteInt64"); err != nil {
return err
}
vw.buf = bsoncore.AppendInt64(vw.buf, i64)
vw.pop()
return nil
}
func (vw *valueWriter) WriteJavascript(code string) error {
if err := vw.writeElementHeader(bsontype.JavaScript, mode(0), "WriteJavascript"); err != nil {
return err
}
vw.buf = bsoncore.AppendJavaScript(vw.buf, code)
vw.pop()
return nil
}
func (vw *valueWriter) WriteMaxKey() error {
if err := vw.writeElementHeader(bsontype.MaxKey, mode(0), "WriteMaxKey"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteMinKey() error {
if err := vw.writeElementHeader(bsontype.MinKey, mode(0), "WriteMinKey"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteNull() error {
if err := vw.writeElementHeader(bsontype.Null, mode(0), "WriteNull"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error {
if err := vw.writeElementHeader(bsontype.ObjectID, mode(0), "WriteObjectID"); err != nil {
return err
}
vw.buf = bsoncore.AppendObjectID(vw.buf, oid)
vw.pop()
return nil
}
func (vw *valueWriter) WriteRegex(pattern string, options string) error {
if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil {
return err
}
vw.buf = bsoncore.AppendRegex(vw.buf, pattern, sortStringAlphebeticAscending(options))
vw.pop()
return nil
}
func (vw *valueWriter) WriteString(s string) error {
if err := vw.writeElementHeader(bsontype.String, mode(0), "WriteString"); err != nil {
return err
}
vw.buf = bsoncore.AppendString(vw.buf, s)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDocument() (DocumentWriter, error) {
if vw.stack[vw.frame].mode == mTopLevel {
vw.reserveLength()
return vw, nil
}
if err := vw.writeElementHeader(bsontype.EmbeddedDocument, mDocument, "WriteDocument", mTopLevel); err != nil {
return nil, err
}
vw.push(mDocument)
return vw, nil
}
func (vw *valueWriter) WriteSymbol(symbol string) error {
if err := vw.writeElementHeader(bsontype.Symbol, mode(0), "WriteSymbol"); err != nil {
return err
}
vw.buf = bsoncore.AppendSymbol(vw.buf, symbol)
vw.pop()
return nil
}
func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error {
if err := vw.writeElementHeader(bsontype.Timestamp, mode(0), "WriteTimestamp"); err != nil {
return err
}
vw.buf = bsoncore.AppendTimestamp(vw.buf, t, i)
vw.pop()
return nil
}
func (vw *valueWriter) WriteUndefined() error {
if err := vw.writeElementHeader(bsontype.Undefined, mode(0), "WriteUndefined"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
switch vw.stack[vw.frame].mode {
case mTopLevel, mDocument:
default:
return nil, vw.invalidTransitionError(mElement, "WriteDocumentElement", []mode{mTopLevel, mDocument})
}
vw.push(mElement)
vw.stack[vw.frame].key = key
return vw, nil
}
func (vw *valueWriter) WriteDocumentEnd() error {
switch vw.stack[vw.frame].mode {
case mTopLevel, mDocument:
default:
return fmt.Errorf("incorrect mode to end document: %s", vw.stack[vw.frame].mode)
}
vw.buf = append(vw.buf, 0x00)
err := vw.writeLength()
if err != nil {
return err
}
if vw.stack[vw.frame].mode == mTopLevel {
if vw.w != nil {
if sw, ok := vw.w.(*SliceWriter); ok {
*sw = vw.buf
} else {
_, err = vw.w.Write(vw.buf)
if err != nil {
return err
}
// reset buffer
vw.buf = vw.buf[:0]
}
}
}
vw.pop()
if vw.stack[vw.frame].mode == mCodeWithScope {
// We ignore the error here because of the gaurantee of writeLength.
// See the docs for writeLength for more info.
_ = vw.writeLength()
vw.pop()
}
return nil
}
func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) {
if vw.stack[vw.frame].mode != mArray {
return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray})
}
arrkey := vw.stack[vw.frame].arrkey
vw.stack[vw.frame].arrkey++
vw.push(mValue)
vw.stack[vw.frame].arrkey = arrkey
return vw, nil
}
func (vw *valueWriter) WriteArrayEnd() error {
if vw.stack[vw.frame].mode != mArray {
return fmt.Errorf("incorrect mode to end array: %s", vw.stack[vw.frame].mode)
}
vw.buf = append(vw.buf, 0x00)
err := vw.writeLength()
if err != nil {
return err
}
vw.pop()
return nil
}
// NOTE: We assume that if we call writeLength more than once the same function
// within the same function without altering the vw.buf that this method will
// not return an error. If this changes ensure that the following methods are
// updated:
//
// - WriteDocumentEnd
func (vw *valueWriter) writeLength() error {
length := len(vw.buf)
if length > maxSize {
return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))}
}
length = length - int(vw.stack[vw.frame].start)
start := vw.stack[vw.frame].start
vw.buf[start+0] = byte(length)
vw.buf[start+1] = byte(length >> 8)
vw.buf[start+2] = byte(length >> 16)
vw.buf[start+3] = byte(length >> 24)
return nil
}

View File

@@ -0,0 +1,96 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ArrayWriter is the interface used to create a BSON or BSON adjacent array.
// Callers must ensure they call WriteArrayEnd when they have finished creating
// the array.
type ArrayWriter interface {
WriteArrayElement() (ValueWriter, error)
WriteArrayEnd() error
}
// DocumentWriter is the interface used to create a BSON or BSON adjacent
// document. Callers must ensure they call WriteDocumentEnd when they have
// finished creating the document.
type DocumentWriter interface {
WriteDocumentElement(string) (ValueWriter, error)
WriteDocumentEnd() error
}
// ValueWriter is the interface used to write BSON values. Implementations of
// this interface handle creating BSON or BSON adjacent representations of the
// values.
type ValueWriter interface {
WriteArray() (ArrayWriter, error)
WriteBinary(b []byte) error
WriteBinaryWithSubtype(b []byte, btype byte) error
WriteBoolean(bool) error
WriteCodeWithScope(code string) (DocumentWriter, error)
WriteDBPointer(ns string, oid primitive.ObjectID) error
WriteDateTime(dt int64) error
WriteDecimal128(primitive.Decimal128) error
WriteDouble(float64) error
WriteInt32(int32) error
WriteInt64(int64) error
WriteJavascript(code string) error
WriteMaxKey() error
WriteMinKey() error
WriteNull() error
WriteObjectID(primitive.ObjectID) error
WriteRegex(pattern, options string) error
WriteString(string) error
WriteDocument() (DocumentWriter, error)
WriteSymbol(symbol string) error
WriteTimestamp(t, i uint32) error
WriteUndefined() error
}
// BytesWriter is the interface used to write BSON bytes to a ValueWriter.
// This interface is meant to be a superset of ValueWriter, so that types that
// implement ValueWriter may also implement this interface.
type BytesWriter interface {
WriteValueBytes(t bsontype.Type, b []byte) error
}
// SliceWriter allows a pointer to a slice of bytes to be used as an io.Writer.
type SliceWriter []byte
func (sw *SliceWriter) Write(p []byte) (int, error) {
written := len(p)
*sw = append(*sw, p...)
return written, nil
}
type writer []byte
func (w *writer) Write(p []byte) (int, error) {
index := len(*w)
return w.WriteAt(p, int64(index))
}
func (w *writer) WriteAt(p []byte, off int64) (int, error) {
newend := off + int64(len(p))
if newend < int64(len(*w)) {
newend = int64(len(*w))
}
if newend > int64(cap(*w)) {
buf := make([]byte, int64(2*cap(*w))+newend)
copy(buf, *w)
*w = buf
}
*w = []byte(*w)[:newend]
copy([]byte(*w)[off:], p)
return len(p), nil
}