mirror of
				https://github.com/go-gitea/gitea
				synced 2025-11-04 05:18:25 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			253 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			253 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// @author Couchbase <info@couchbase.com>
 | 
						|
// @copyright 2018 Couchbase, Inc.
 | 
						|
//
 | 
						|
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
// you may not use this file except in compliance with the License.
 | 
						|
// You may obtain a copy of the License at
 | 
						|
//
 | 
						|
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
//
 | 
						|
// Unless required by applicable law or agreed to in writing, software
 | 
						|
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
// See the License for the specific language governing permissions and
 | 
						|
// limitations under the License.
 | 
						|
 | 
						|
// Package scramsha provides a client-side implementation of SCRAM-SHA
// authentication over HTTP, as specified in https://tools.ietf.org/html/rfc7804
 | 
						|
package scramsha
 | 
						|
 | 
						|
import (
 | 
						|
	"encoding/base64"
 | 
						|
	"github.com/pkg/errors"
 | 
						|
	"io"
 | 
						|
	"io/ioutil"
 | 
						|
	"net/http"
 | 
						|
	"strings"
 | 
						|
)
 | 
						|
 | 
						|
// HTTP header names exchanged during the SCRAM-SHA handshake.
const (
	WWWAuthenticate    = "WWW-Authenticate"    // server challenge on the 401 reply
	AuthenticationInfo = "Authentication-Info" // server reply to the final message
	Authorization      = "Authorization"       // client start / final messages
)

// Prefixes of the auth-params carried inside those header values.
const (
	DataPrefix = "data=" // base64-encoded SCRAM message
	SidPrefix  = "sid="  // session id taken from the server and echoed back
)
 | 
						|
 | 
						|
// Request provides implementation of http request that can be retried
 | 
						|
type Request struct {
 | 
						|
	body io.ReadSeeker
 | 
						|
 | 
						|
	// Embed an HTTP request directly. This makes a *Request act exactly
 | 
						|
	// like an *http.Request so that all meta methods are supported.
 | 
						|
	*http.Request
 | 
						|
}
 | 
						|
 | 
						|
// lenReader is satisfied by bodies that can report their length up front
// (for example *strings.Reader); NewRequest uses it to fill ContentLength.
type lenReader interface{ Len() int }
 | 
						|
 | 
						|
// NewRequest builds a *Request whose body can be rewound and resent.
//
// body may be nil. A non-nil body is handed to net/http wrapped in a no-op
// closer, so the HTTP client cannot close the caller's reader. When body
// also implements Len() int, that length becomes the ContentLength.
// Any error from http.NewRequest is returned unchanged.
func NewRequest(method, url string, body io.ReadSeeker) (*Request, error) {
	var payload io.ReadCloser
	if body != nil {
		payload = ioutil.NopCloser(body)
	}

	inner, err := http.NewRequest(method, url, payload)
	if err != nil {
		return nil, err
	}

	// The NopCloser wrapper hides the concrete reader type from net/http,
	// so the length has to be propagated by hand.
	if sized, hasLen := body.(lenReader); hasLen {
		inner.ContentLength = int64(sized.Len())
	}

	return &Request{body: body, Request: inner}, nil
}
 | 
						|
 | 
						|
// encode returns the standard (padded) base64 encoding of str.
func encode(str string) string {
	enc := base64.StdEncoding
	out := make([]byte, enc.EncodedLen(len(str)))
	enc.Encode(out, []byte(str))
	return string(out)
}
 | 
						|
 | 
						|
// decode returns the standard (padded) base64 decoding of str as a string.
//
// On failure the returned error wraps the underlying base64 error, so the
// offending input offset is kept alongside the message.
func decode(str string) (string, error) {
	raw, err := base64.StdEncoding.DecodeString(str)
	if err != nil {
		return "", errors.Wrapf(err, "Cannot base64 decode %s", str)
	}
	return string(raw), nil
}
 | 
						|
 | 
						|
// trimPrefix strips prefix from the front of s.
//
// If s does not begin with prefix, s is returned unchanged together with an
// error, so callers can still inspect the original string. Because success
// is detected by a change in length, an empty prefix is also reported as
// not found.
func trimPrefix(s, prefix string) (string, error) {
	if rest := strings.TrimPrefix(s, prefix); len(rest) != len(s) {
		return rest, nil
	}
	return s, errors.Errorf("Prefix %s not found in %s",
		prefix, s)
}
 | 
						|
 | 
						|
// drainBody reads and discards whatever remains of resp.Body, then closes it.
//
// It is best-effort cleanup for responses that are not handed back to the
// caller: a read error is deliberately ignored. Fully consuming a body
// before closing it lets net/http reuse the underlying connection.
func drainBody(resp *http.Response) {
	body := resp.Body
	defer body.Close()
	_, _ = io.Copy(ioutil.Discard, body)
}
 | 
						|
 | 
						|
// DoScramSha performs a SCRAM-SHA-512 handshake over HTTP (RFC 7804) for req
// and returns the server's response.
//
// req is sent with client up to twice:
//  1. with an Authorization header carrying the client start message;
//  2. after a 401 whose WWW-Authenticate header carries the SCRAM challenge,
//     with the final message and the challenge's sid, after rewinding the
//     request body.
//
// The returned response is owned by the caller, who must close its body. It
// is returned without SCRAM verification (err == nil) when the first attempt
// is not answered with 401, when the challenge is a "Basic " one, or when
// the final attempt is answered with 401 or a 5xx status. Otherwise the
// final response's Authentication-Info header must carry the same sid and a
// payload accepted by the handler's HandleFinalResponse.
//
// Whenever an error is returned, any response body received so far has
// already been drained and closed.
func DoScramSha(req *Request, username, password string, client *http.Client) (*http.Response, error) {
	// method is the auth-scheme token sent on the wire. NOTE(review): the
	// handler is created with a different spelling ("SCRAM-SHA512");
	// presumably that is the form NewScramSha expects — confirm.
	method := "SCRAM-SHA-512"
	s, err := NewScramSha("SCRAM-SHA512")
	if err != nil {
		return nil, errors.Wrap(err,
			"Unable to initialize SCRAM-SHA handler")
	}

	// --- First leg: client start message. ---
	message, err := s.GetStartRequest(username)
	if err != nil {
		return nil, err
	}

	encodedMessage := method + " " + DataPrefix + encode(message)
	req.Header.Set(Authorization, encodedMessage)

	res, err := client.Do(req.Request)
	if err != nil {
		return nil, errors.Wrap(err, "Problem sending SCRAM-SHA start "+
			"request")
	}

	// Anything other than 401 means the server did not challenge us.
	if res.StatusCode != http.StatusUnauthorized {
		return res, nil
	}

	authHeader := res.Header.Get(WWWAuthenticate)
	if authHeader == "" {
		drainBody(res)
		return nil, errors.Errorf("Header %s is not populated in "+
			"SCRAM-SHA start response", WWWAuthenticate)
	}

	authHeader, err = trimPrefix(authHeader, method+" ")
	if err != nil {
		// On failure trimPrefix hands back the header unchanged, so it can
		// still be inspected for a Basic challenge.
		if strings.HasPrefix(authHeader, "Basic ") {
			// user not found
			return res, nil
		}
		drainBody(res)
		return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+
			"start response %s", authHeader)
	}

	// The challenge is carried entirely in the header; the 401 body is not
	// needed.
	drainBody(res)

	sid, response, err := parseSidAndData(authHeader)
	if err != nil {
		return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+
			"start response %s", authHeader)
	}

	err = s.HandleStartResponse(response)
	if err != nil {
		return nil, errors.Wrapf(err, "Error parsing SCRAM-SHA start "+
			"response %s", response)
	}

	// --- Second leg: client final message. ---
	message = s.GetFinalRequest(password)
	encodedMessage = method + " " + SidPrefix + sid + "," + DataPrefix +
		encode(message)
	req.Header.Set(Authorization, encodedMessage)

	// The first attempt consumed the body; rewind it so it can be resent.
	if req.body != nil {
		if _, err = req.body.Seek(0, io.SeekStart); err != nil {
			return nil, errors.Wrap(err, "Failed to seek body")
		}
	}

	res, err = client.Do(req.Request)
	if err != nil {
		return nil, errors.Wrap(err, "Problem sending SCRAM-SHA final "+
			"request")
	}

	if res.StatusCode == http.StatusUnauthorized {
		// TODO retrieve and return error
		return res, nil
	}

	if res.StatusCode >= http.StatusInternalServerError {
		// in this case we cannot expect server to set headers properly
		return res, nil
	}

	// --- Verify the server's final message before trusting the response. ---
	authHeader = res.Header.Get(AuthenticationInfo)
	if authHeader == "" {
		drainBody(res)
		return nil, errors.Errorf("Header %s is not populated in "+
			"SCRAM-SHA final response", AuthenticationInfo)
	}

	finalSid, response, err := parseSidAndData(authHeader)
	if err != nil {
		drainBody(res)
		return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+
			"final response %s", authHeader)
	}

	if finalSid != sid {
		drainBody(res)
		return nil, errors.Errorf("Sid %s returned by server "+
			"doesn't match the original sid %s", finalSid, sid)
	}

	err = s.HandleFinalResponse(response)
	if err != nil {
		drainBody(res)
		return nil, errors.Wrapf(err,
			"Error handling SCRAM-SHA final server response %s",
			response)
	}
	return res, nil
}
 | 
						|
 | 
						|
func parseSidAndData(authHeader string) (string, string, error) {
 | 
						|
	sidIndex := strings.Index(authHeader, SidPrefix)
 | 
						|
	if sidIndex < 0 {
 | 
						|
		return "", "", errors.Errorf("Cannot find %s in %s",
 | 
						|
			SidPrefix, authHeader)
 | 
						|
	}
 | 
						|
 | 
						|
	sidEndIndex := strings.Index(authHeader, ",")
 | 
						|
	if sidEndIndex < 0 {
 | 
						|
		return "", "", errors.Errorf("Cannot find ',' in %s",
 | 
						|
			authHeader)
 | 
						|
	}
 | 
						|
 | 
						|
	sid := authHeader[sidIndex+len(SidPrefix) : sidEndIndex]
 | 
						|
 | 
						|
	dataIndex := strings.Index(authHeader, DataPrefix)
 | 
						|
	if dataIndex < 0 {
 | 
						|
		return "", "", errors.Errorf("Cannot find %s in %s",
 | 
						|
			DataPrefix, authHeader)
 | 
						|
	}
 | 
						|
 | 
						|
	data, err := decode(authHeader[dataIndex+len(DataPrefix):])
 | 
						|
	if err != nil {
 | 
						|
		return "", "", err
 | 
						|
	}
 | 
						|
	return sid, data, nil
 | 
						|
}
 |