// Package retryablehttp provides a familiar HTTP client interface with
// automatic retries and exponential backoff. It is a thin wrapper over the
// standard net/http client library and exposes nearly the same public API.
// This makes retryablehttp very easy to drop into existing programs.
//
// retryablehttp performs automatic retries under certain conditions. Mainly, if
// an error is returned by the client (connection errors etc), or if a 500-range
// response is received, then a retry is invoked. Otherwise, the response is
// returned and left to the caller to interpret.
//
// Requests which take a request body should provide a non-nil function
// parameter. The best choice is to provide either a function satisfying
// ReaderFunc which provides multiple io.Readers in an efficient manner, a
// *bytes.Buffer (the underlying raw byte slice will be used) or a raw byte
// slice. As it is a reference type, and we will wrap it as needed by readers,
// we can efficiently re-use the request body without needing to copy it. If an
// io.Reader (such as a *bytes.Reader) is provided, the full body will be read
// prior to the first request, and will be efficiently re-used for any retries.
// ReadSeeker can be used, but some users have observed occasional data races
// between the net/http library and the Seek functionality of some
// implementations of ReadSeeker, so should be avoided if possible.
package retryablehttp

import (
	"bytes"
	"context"
	"crypto/x509"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"math"
	"math/rand"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	cleanhttp "github.com/hashicorp/go-cleanhttp"
)

var (
	// Default retry configuration
	defaultRetryWaitMin = 1 * time.Second
	defaultRetryWaitMax = 30 * time.Second
	defaultRetryMax     = 4

	// defaultLogger is the logger provided with defaultClient
	defaultLogger = log.New(os.Stderr, "", log.LstdFlags)

	// defaultClient is used for performing requests without explicitly making
	// a new client. It is purposely private to avoid modifications.
	defaultClient = NewClient()

	// We need to consume response bodies to maintain http connections, but
	// limit the size we consume to respReadLimit. Bodies larger than this
	// are abandoned (and the connection closed) rather than fully drained.
	respReadLimit = int64(4096)

	// A regular expression to match the error returned by net/http when the
	// configured number of redirects is exhausted. This error isn't typed
	// specifically so we resort to matching on the error string.
	redirectsErrorRe = regexp.MustCompile(`stopped after \d+ redirects\z`)

	// A regular expression to match the error returned by net/http when the
	// scheme specified in the URL is invalid. This error isn't typed
	// specifically so we resort to matching on the error string.
	schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`)
)

// ReaderFunc is the type of function that can be given natively to NewRequest.
// Each invocation must return a fresh reader positioned at the start of the
// body so the body can be re-sent on retry.
type ReaderFunc func() (io.Reader, error)

// LenReader is an interface implemented by many in-memory io.Reader's. Used
// for automatically sending the right Content-Length header when possible.
type LenReader interface {
	Len() int
}

// Request wraps the metadata needed to create HTTP requests.
type Request struct {
	// body is a seekable reader over the request body payload. This is
	// used to rewind the request data in between retries.
	body ReaderFunc

	// Embed an HTTP request directly. This makes a *Request act exactly
	// like an *http.Request so that all meta methods are supported.
	*http.Request
}

// WithContext returns wrapped Request with a shallow copy of underlying *http.Request
// with its context changed to ctx. The provided ctx must be non-nil.
//
// Note that unlike http.Request.WithContext, this mutates the receiver in
// place (r itself is returned, not a copy).
func (r *Request) WithContext(ctx context.Context) *Request {
	r.Request = r.Request.WithContext(ctx)
	return r
}

// BodyBytes allows accessing the request body. It is an analogue to
// http.Request's Body variable, but it returns a copy of the underlying data
// rather than consuming it.
//
// This function is not thread-safe; do not call it at the same time as another
// call, or at the same time this request is being used with Client.Do.
func (r *Request) BodyBytes() ([]byte, error) {
	if r.body == nil {
		// No body was ever set; mirror http.Request's nil-body convention.
		return nil, nil
	}
	body, err := r.body()
	if err != nil {
		return nil, err
	}
	buf := new(bytes.Buffer)
	_, err = buf.ReadFrom(body)
	if err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// SetBody allows setting the request body.
//
// It is useful if a new body needs to be set without constructing a new Request.
// rawBody accepts the same set of types as NewRequest (see
// getBodyReaderAndContentLength). ContentLength is updated to match.
func (r *Request) SetBody(rawBody interface{}) error {
	bodyReader, contentLength, err := getBodyReaderAndContentLength(rawBody)
	if err != nil {
		return err
	}
	r.body = bodyReader
	r.ContentLength = contentLength
	return nil
}

// WriteTo allows copying the request body into a writer.
//
// It writes data to w until there's no more data to write or
// when an error occurs. The return int64 value is the number of bytes
// written. Any error encountered during the write is also returned.
// The signature matches io.WriterTo interface.
//
// NOTE(review): this panics if no body was set (r.body == nil) — callers are
// expected to only use it on requests that carry a body.
func (r *Request) WriteTo(w io.Writer) (int64, error) {
	body, err := r.body()
	if err != nil {
		return 0, err
	}
	if c, ok := body.(io.Closer); ok {
		defer c.Close()
	}
	return io.Copy(w, body)
}

// getBodyReaderAndContentLength normalizes the many accepted body types into
// a ReaderFunc that can re-produce the body for each retry, along with the
// content length when it can be determined (0 otherwise).
func getBodyReaderAndContentLength(rawBody interface{}) (ReaderFunc, int64, error) {
	var bodyReader ReaderFunc
	var contentLength int64

	switch body := rawBody.(type) {
	// If they gave us a function already, great! Use it.
	case ReaderFunc:
		bodyReader = body
		// Invoke once up front purely to probe for a length; close the
		// probe reader if it needs closing.
		tmp, err := body()
		if err != nil {
			return nil, 0, err
		}
		if lr, ok := tmp.(LenReader); ok {
			contentLength = int64(lr.Len())
		}
		if c, ok := tmp.(io.Closer); ok {
			c.Close()
		}

	// Same as ReaderFunc but for callers passing the raw function type
	// (an unnamed func does not match the named ReaderFunc case above).
	case func() (io.Reader, error):
		bodyReader = body
		tmp, err := body()
		if err != nil {
			return nil, 0, err
		}
		if lr, ok := tmp.(LenReader); ok {
			contentLength = int64(lr.Len())
		}
		if c, ok := tmp.(io.Closer); ok {
			c.Close()
		}

	// If a regular byte slice, we can read it over and over via new
	// readers
	case []byte:
		buf := body
		bodyReader = func() (io.Reader, error) {
			return bytes.NewReader(buf), nil
		}
		contentLength = int64(len(buf))

	// If a bytes.Buffer we can read the underlying byte slice over and
	// over
	case *bytes.Buffer:
		buf := body
		bodyReader = func() (io.Reader, error) {
			return bytes.NewReader(buf.Bytes()), nil
		}
		contentLength = int64(buf.Len())

	// We prioritize *bytes.Reader here because we don't really want to
	// deal with it seeking so want it to match here instead of the
	// io.ReadSeeker case.
	case *bytes.Reader:
		buf, err := ioutil.ReadAll(body)
		if err != nil {
			return nil, 0, err
		}
		bodyReader = func() (io.Reader, error) {
			return bytes.NewReader(buf), nil
		}
		contentLength = int64(len(buf))

	// Compat case
	case io.ReadSeeker:
		raw := body
		bodyReader = func() (io.Reader, error) {
			// Rewind to the start on every invocation so retries
			// re-send the full body.
			_, err := raw.Seek(0, 0)
			return ioutil.NopCloser(raw), err
		}
		if lr, ok := raw.(LenReader); ok {
			contentLength = int64(lr.Len())
		}

	// Read all in so we can reset
	case io.Reader:
		buf, err := ioutil.ReadAll(body)
		if err != nil {
			return nil, 0, err
		}
		bodyReader = func() (io.Reader, error) {
			return bytes.NewReader(buf), nil
		}
		contentLength = int64(len(buf))

	// No body provided, nothing to do
	case nil:

	// Unrecognized type
	default:
		return nil, 0, fmt.Errorf("cannot handle type %T", rawBody)
	}
	return bodyReader, contentLength, nil
}

// FromRequest wraps an http.Request in a retryablehttp.Request.
// The original request's ContentLength is preserved as-is (it is not
// recomputed from the body).
func FromRequest(r *http.Request) (*Request, error) {
	bodyReader, _, err := getBodyReaderAndContentLength(r.Body)
	if err != nil {
		return nil, err
	}
	// Could assert contentLength == r.ContentLength
	return &Request{bodyReader, r}, nil
}

// NewRequest creates a new wrapped request.
//
// rawBody may be nil, a []byte, *bytes.Buffer, *bytes.Reader, io.ReadSeeker,
// io.Reader, or a ReaderFunc; see getBodyReaderAndContentLength.
func NewRequest(method, url string, rawBody interface{}) (*Request, error) {
	bodyReader, contentLength, err := getBodyReaderAndContentLength(rawBody)
	if err != nil {
		return nil, err
	}

	// Build the underlying request with a nil body; the retry loop in
	// Client.Do installs a fresh body from bodyReader before each attempt.
	httpReq, err := http.NewRequest(method, url, nil)
	if err != nil {
		return nil, err
	}
	httpReq.ContentLength = contentLength

	return &Request{bodyReader, httpReq}, nil
}

// Logger interface allows to use other loggers than
// standard log.Logger.
type Logger interface {
	Printf(string, ...interface{})
}

// LeveledLogger is an interface that can be implemented by any logger or a
// logger wrapper to provide leveled logging. The methods accept a message
// string and a variadic number of key-value pairs. For log.Printf style
// formatting where message string contains a format specifier, use Logger
// interface.
type LeveledLogger interface {
	Error(msg string, keysAndValues ...interface{})
	Info(msg string, keysAndValues ...interface{})
	Debug(msg string, keysAndValues ...interface{})
	Warn(msg string, keysAndValues ...interface{})
}

// hookLogger adapts an LeveledLogger to Logger for use by the existing hook functions
// without changing the API.
type hookLogger struct {
	LeveledLogger
}

// Printf satisfies Logger by formatting the message and emitting it at Info
// level on the wrapped LeveledLogger.
func (h hookLogger) Printf(s string, args ...interface{}) {
	h.Info(fmt.Sprintf(s, args...))
}

// RequestLogHook allows a function to run before each retry. The HTTP
// request which will be made, and the retry number (0 for the initial
// request) are available to users. The internal logger is exposed to
// consumers.
type RequestLogHook func(Logger, *http.Request, int)

// ResponseLogHook is like RequestLogHook, but allows running a function
// on each HTTP response. This function will be invoked at the end of
// every HTTP request executed, regardless of whether a subsequent retry
// needs to be performed or not. If the response body is read or closed
// from this method, this will affect the response returned from Do().
type ResponseLogHook func(Logger, *http.Response)

// CheckRetry specifies a policy for handling retries. It is called
// following each request with the response and error values returned by
// the http.Client. If CheckRetry returns false, the Client stops retrying
// and returns the response to the caller. If CheckRetry returns an error,
// that error value is returned in lieu of the error from the request. The
// Client will close any response body when retrying, but if the retry is
// aborted it is up to the CheckRetry callback to properly close any
// response body before returning.
type CheckRetry func(ctx context.Context, resp *http.Response, err error) (bool, error)

// Backoff specifies a policy for how long to wait between retries.
// It is called after a failing request to determine the amount of time
// that should pass before trying again.
type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration

// ErrorHandler is called if retries are expired, containing the last status
// from the http library. If not specified, default behavior for the library is
// to close the body and return an error indicating how many tries were
// attempted. If overriding this, be sure to close the body if needed.
type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Response, error)

// Client is used to make HTTP requests. It adds additional functionality
// like automatic retries to tolerate minor outages.
type Client struct {
	HTTPClient   *http.Client  // Internal HTTP client.
	Logger       interface{}   // Customer logger instance. Can be either Logger or LeveledLogger
	RetryWaitMin time.Duration // Minimum time to wait
	RetryWaitMax time.Duration // Maximum time to wait
	RetryMax     int           // Maximum number of retries

	// RequestLogHook allows a user-supplied function to be called
	// before each retry.
	RequestLogHook RequestLogHook

	// ResponseLogHook allows a user-supplied function to be called
	// with the response from each HTTP request executed.
	ResponseLogHook ResponseLogHook

	// CheckRetry specifies the policy for handling retries, and is called
	// after each request. The default policy is DefaultRetryPolicy.
	CheckRetry CheckRetry

	// Backoff specifies the policy for how long to wait between retries
	Backoff Backoff

	// ErrorHandler specifies the custom error handler to use, if any
	ErrorHandler ErrorHandler

	// loggerInit guards one-time validation of the Logger field; clientInit
	// guards lazy creation of HTTPClient. sync.Once fields make *Client
	// non-copyable after first use.
	loggerInit sync.Once
	clientInit sync.Once
}

// NewClient creates a new Client with default settings.
func NewClient() *Client {
	return &Client{
		HTTPClient:   cleanhttp.DefaultPooledClient(),
		Logger:       defaultLogger,
		RetryWaitMin: defaultRetryWaitMin,
		RetryWaitMax: defaultRetryWaitMax,
		RetryMax:     defaultRetryMax,
		CheckRetry:   DefaultRetryPolicy,
		Backoff:      DefaultBackoff,
	}
}

// logger validates c.Logger exactly once and returns it. A nil Logger is
// allowed (logging disabled); any type other than Logger or LeveledLogger
// panics, as that is a programming error rather than a runtime condition.
func (c *Client) logger() interface{} {
	c.loggerInit.Do(func() {
		if c.Logger == nil {
			return
		}

		switch c.Logger.(type) {
		case Logger, LeveledLogger:
			// ok
		default:
			// This should happen in dev when they are setting Logger and work on code, not in prod.
			panic(fmt.Sprintf("invalid logger type passed, must be Logger or LeveledLogger, was %T", c.Logger))
		}
	})

	return c.Logger
}

// DefaultRetryPolicy provides a default callback for Client.CheckRetry, which
// will retry on connection errors and server errors.
func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
	// do not retry on context.Canceled or context.DeadlineExceeded
	if ctx.Err() != nil {
		return false, ctx.Err()
	}

	// don't propagate other errors
	shouldRetry, _ := baseRetryPolicy(resp, err)
	return shouldRetry, nil
}

// ErrorPropagatedRetryPolicy is the same as DefaultRetryPolicy, except it
// propagates errors back instead of returning nil. This allows you to inspect
// why it decided to retry or not.
func ErrorPropagatedRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
	// do not retry on context.Canceled or context.DeadlineExceeded
	if ctx.Err() != nil {
		return false, ctx.Err()
	}

	return baseRetryPolicy(resp, err)
}

// baseRetryPolicy contains the shared retry decision: transport errors are
// retried unless they are clearly permanent (redirect exhaustion, bad scheme,
// untrusted TLS cert); 429 and most 5xx responses are retried.
// NOTE(review): resp is dereferenced when err is nil — callers must ensure
// (resp == nil) implies (err != nil), which http.Client.Do guarantees.
func baseRetryPolicy(resp *http.Response, err error) (bool, error) {
	if err != nil {
		if v, ok := err.(*url.Error); ok {
			// Don't retry if the error was due to too many redirects.
			if redirectsErrorRe.MatchString(v.Error()) {
				return false, v
			}

			// Don't retry if the error was due to an invalid protocol scheme.
			if schemeErrorRe.MatchString(v.Error()) {
				return false, v
			}

			// Don't retry if the error was due to TLS cert verification failure.
			if _, ok := v.Err.(x509.UnknownAuthorityError); ok {
				return false, v
			}
		}

		// The error is likely recoverable so retry.
		return true, nil
	}

	// 429 Too Many Requests is recoverable. Sometimes the server puts
	// a Retry-After response header to indicate when the server is
	// available to start processing request from client.
	if resp.StatusCode == http.StatusTooManyRequests {
		return true, nil
	}

	// Check the response code. We retry on 500-range responses to allow
	// the server time to recover, as 500's are typically not permanent
	// errors and may relate to outages on the server side. This will catch
	// invalid response codes as well, like 0 and 999.
	// 501 Not Implemented is excluded: retrying cannot help there.
	if resp.StatusCode == 0 || (resp.StatusCode >= 500 && resp.StatusCode != 501) {
		return true, fmt.Errorf("unexpected HTTP status %s", resp.Status)
	}

	return false, nil
}

// DefaultBackoff provides a default callback for Client.Backoff which
// will perform exponential backoff based on the attempt number and limited
// by the provided minimum and maximum durations.
//
// It also tries to parse Retry-After response header when a http.StatusTooManyRequests
// (HTTP Code 429) is found in the resp parameter. Hence it will return the number of
// seconds the server states it may be ready to process more requests from this client.
// Only the integer-seconds form of Retry-After is honored (HTTP-date values
// fail strconv.ParseInt and fall through to exponential backoff), and the
// server-supplied value is used verbatim, uncapped by max.
func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
	if resp != nil {
		if resp.StatusCode == http.StatusTooManyRequests {
			if s, ok := resp.Header["Retry-After"]; ok {
				if sleep, err := strconv.ParseInt(s[0], 10, 64); err == nil {
					return time.Second * time.Duration(sleep)
				}
			}
		}
	}

	// min * 2^attemptNum, clamped to max. The float64 round-trip check
	// also catches overflow: if the product no longer fits a Duration,
	// sleep != mult and we clamp to max.
	mult := math.Pow(2, float64(attemptNum)) * float64(min)
	sleep := time.Duration(mult)
	if float64(sleep) != mult || sleep > max {
		sleep = max
	}
	return sleep
}

// LinearJitterBackoff provides a callback for Client.Backoff which will
// perform linear backoff based on the attempt number and with jitter to
// prevent a thundering herd.
//
// min and max here are *not* absolute values. The number to be multiplied by
// the attempt number will be chosen at random from between them, thus they are
// bounding the jitter.
//
// For instance:
// * To get strictly linear backoff of one second increasing each retry, set
// both to one second (1s, 2s, 3s, 4s, ...)
// * To get a small amount of jitter centered around one second increasing each
// retry, set to around one second, such as a min of 800ms and max of 1200ms
// (892ms, 2102ms, 2945ms, 4312ms, ...)
// * To get extreme jitter, set to a very wide spread, such as a min of 100ms
// and a max of 20s (15382ms, 292ms, 51321ms, 35234ms, ...)
func LinearJitterBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
	// attemptNum always starts at zero but we want to start at 1 for multiplication
	attemptNum++

	if max <= min {
		// Unclear what to do here, or they are the same, so return min *
		// attemptNum
		return min * time.Duration(attemptNum)
	}

	// Seed rand; doing this every time is fine
	// (shadows the package-level math/rand import within this function).
	rand := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))

	// Pick a random number that lies somewhere between the min and max and
	// multiply by the attemptNum. attemptNum starts at zero so we always
	// increment here. We first get a random percentage, then apply that to the
	// difference between min and max, and add to min.
	jitter := rand.Float64() * float64(max-min)
	jitterMin := int64(jitter) + int64(min)
	return time.Duration(jitterMin * int64(attemptNum))
}

// PassthroughErrorHandler is an ErrorHandler that directly passes through the
// values from the net/http library for the final request. The body is not
// closed.
func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Response, error) {
	return resp, err
}

// Do wraps calling an HTTP method with retries.
func (c *Client) Do(req *Request) (*http.Response, error) {
	// Lazily create the HTTP client so a zero-value Client is usable.
	c.clientInit.Do(func() {
		if c.HTTPClient == nil {
			c.HTTPClient = cleanhttp.DefaultPooledClient()
		}
	})

	logger := c.logger()

	if logger != nil {
		switch v := logger.(type) {
		case LeveledLogger:
			v.Debug("performing request", "method", req.Method, "url", req.URL)
		case Logger:
			v.Printf("[DEBUG] %s %s", req.Method, req.URL)
		}
	}

	var resp *http.Response
	var attempt int // total attempts made (1-based); i below is 0-based retry index
	var shouldRetry bool
	var doErr, checkErr error

	for i := 0; ; i++ {
		attempt++

		var code int // HTTP response code

		// Always rewind the request body when non-nil.
		if req.body != nil {
			body, err := req.body()
			if err != nil {
				c.HTTPClient.CloseIdleConnections()
				return resp, err
			}
			// NOTE: `c` here intentionally shadows the receiver for the
			// duration of this if-statement.
			if c, ok := body.(io.ReadCloser); ok {
				req.Body = c
			} else {
				req.Body = ioutil.NopCloser(body)
			}
		}

		if c.RequestLogHook != nil {
			switch v := logger.(type) {
			case LeveledLogger:
				c.RequestLogHook(hookLogger{v}, req.Request, i)
			case Logger:
				c.RequestLogHook(v, req.Request, i)
			default:
				c.RequestLogHook(nil, req.Request, i)
			}
		}

		// Attempt the request
		resp, doErr = c.HTTPClient.Do(req.Request)
		if resp != nil {
			code = resp.StatusCode
		}

		// Check if we should continue with retries.
		shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr)

		if doErr != nil {
			switch v := logger.(type) {
			case LeveledLogger:
				v.Error("request failed", "error", doErr, "method", req.Method, "url", req.URL)
			case Logger:
				v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, doErr)
			}
		} else {
			// Call this here to maintain the behavior of logging all requests,
			// even if CheckRetry signals to stop.
			if c.ResponseLogHook != nil {
				// Call the response logger function if provided.
				switch v := logger.(type) {
				case LeveledLogger:
					c.ResponseLogHook(hookLogger{v}, resp)
				case Logger:
					c.ResponseLogHook(v, resp)
				default:
					c.ResponseLogHook(nil, resp)
				}
			}
		}

		if !shouldRetry {
			break
		}

		// We do this before drainBody because there's no need for the I/O if
		// we're breaking out
		remain := c.RetryMax - i
		if remain <= 0 {
			break
		}

		// We're going to retry, consume any response to reuse the connection.
		if doErr == nil {
			c.drainBody(resp.Body)
		}

		wait := c.Backoff(c.RetryWaitMin, c.RetryWaitMax, i, resp)
		desc := fmt.Sprintf("%s %s", req.Method, req.URL)
		if code > 0 {
			desc = fmt.Sprintf("%s (status: %d)", desc, code)
		}
		if logger != nil {
			switch v := logger.(type) {
			case LeveledLogger:
				v.Debug("retrying request", "request", desc, "timeout", wait, "remaining", remain)
			case Logger:
				v.Printf("[DEBUG] %s: retrying in %s (%d left)", desc, wait, remain)
			}
		}
		// Sleep for the backoff period, but abandon immediately if the
		// request context is cancelled.
		select {
		case <-req.Context().Done():
			c.HTTPClient.CloseIdleConnections()
			return nil, req.Context().Err()
		case <-time.After(wait):
		}

		// Make shallow copy of http Request so that we can modify its body
		// without racing against the closeBody call in persistConn.writeLoop.
		httpreq := *req.Request
		req.Request = &httpreq
	}

	// this is the closest we have to success criteria
	if doErr == nil && checkErr == nil && !shouldRetry {
		return resp, nil
	}

	defer c.HTTPClient.CloseIdleConnections()

	// A CheckRetry error takes precedence over the transport error.
	err := doErr
	if checkErr != nil {
		err = checkErr
	}

	if c.ErrorHandler != nil {
		return c.ErrorHandler(resp, err, attempt)
	}

	// By default, we close the response body and return an error without
	// returning the response
	if resp != nil {
		c.drainBody(resp.Body)
	}

	// this means CheckRetry thought the request was a failure, but didn't
	// communicate why
	if err == nil {
		return nil, fmt.Errorf("%s %s giving up after %d attempt(s)", req.Method, req.URL, attempt)
	}

	return nil, fmt.Errorf("%s %s giving up after %d attempt(s): %w", req.Method, req.URL, attempt, err)
}

// Try to read the response body so we can reuse this connection.
// At most respReadLimit bytes are consumed before the body is closed.
func (c *Client) drainBody(body io.ReadCloser) {
	defer body.Close()
	_, err := io.Copy(ioutil.Discard, io.LimitReader(body, respReadLimit))
	if err != nil {
		if c.logger() != nil {
			switch v := c.logger().(type) {
			case LeveledLogger:
				v.Error("error reading response body", "error", err)
			case Logger:
				v.Printf("[ERR] error reading response body: %v", err)
			}
		}
	}
}

// Get is a shortcut for doing a GET request without making a new client.
func Get(url string) (*http.Response, error) {
	return defaultClient.Get(url)
}

// Get is a convenience helper for doing simple GET requests.
func (c *Client) Get(url string) (*http.Response, error) {
	req, err := NewRequest("GET", url, nil)
	if err != nil {
		return nil, err
	}
	return c.Do(req)
}

// Head is a shortcut for doing a HEAD request without making a new client.
func Head(url string) (*http.Response, error) {
	return defaultClient.Head(url)
}

// Head is a convenience method for doing simple HEAD requests.
func (c *Client) Head(url string) (*http.Response, error) {
	req, err := NewRequest("HEAD", url, nil)
	if err != nil {
		return nil, err
	}
	return c.Do(req)
}

// Post is a shortcut for doing a POST request without making a new client.
func Post(url, bodyType string, body interface{}) (*http.Response, error) {
	return defaultClient.Post(url, bodyType, body)
}

// Post is a convenience method for doing simple POST requests.
func (c *Client) Post(url, bodyType string, body interface{}) (*http.Response, error) {
	req, err := NewRequest("POST", url, body)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", bodyType)
	return c.Do(req)
}

// PostForm is a shortcut to perform a POST with form data without creating
// a new client.
func PostForm(url string, data url.Values) (*http.Response, error) {
	return defaultClient.PostForm(url, data)
}

// PostForm is a convenience method for doing simple POST operations using
// pre-filled url.Values form data.
func (c *Client) PostForm(url string, data url.Values) (*http.Response, error) {
	return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}

// StandardClient returns a stdlib *http.Client with a custom Transport, which
// shims in a *retryablehttp.Client for added retries.
func (c *Client) StandardClient() *http.Client {
	return &http.Client{
		Transport: &RoundTripper{Client: c},
	}
}