From 6073e2f1bbf83d6f92b1ff696f75d7c8a4520ac8 Mon Sep 17 00:00:00 2001 From: wxiaoguang Date: Mon, 20 Jan 2025 14:25:17 +0800 Subject: [PATCH] Refactor response writer & access logger (#33323) And add comments & tests --- routers/common/middleware.go | 12 ++-- services/context/access_log.go | 103 ++++++++++++++++------------ services/context/access_log_test.go | 75 ++++++++++++++++++++ services/context/response.go | 37 +++++----- 4 files changed, 159 insertions(+), 68 deletions(-) create mode 100644 services/context/access_log_test.go diff --git a/routers/common/middleware.go b/routers/common/middleware.go index 12b0c67b01..de59b78583 100644 --- a/routers/common/middleware.go +++ b/routers/common/middleware.go @@ -43,14 +43,18 @@ func ProtocolMiddlewares() (handlers []any) { func RequestContextHandler() func(h http.Handler) http.Handler { return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - profDesc := fmt.Sprintf("%s: %s", req.Method, req.RequestURI) + return http.HandlerFunc(func(respOrig http.ResponseWriter, req *http.Request) { + // this response writer might not be the same as the one in context.Base.Resp + // because there might be a "gzip writer" in the middle, so the "written size" here is the compressed size + respWriter := context.WrapResponseWriter(respOrig) + + profDesc := fmt.Sprintf("HTTP: %s %s", req.Method, req.RequestURI) ctx, finished := reqctx.NewRequestContext(req.Context(), profDesc) defer finished() defer func() { if err := recover(); err != nil { - RenderPanicErrorPage(resp, req, err) // it should never panic + RenderPanicErrorPage(respWriter, req, err) // it should never panic } }() @@ -62,7 +66,7 @@ func RequestContextHandler() func(h http.Handler) http.Handler { _ = req.MultipartForm.RemoveAll() // remove the temp files buffered to tmp directory } }) - next.ServeHTTP(context.WrapResponseWriter(resp), req) + next.ServeHTTP(respWriter, req) }) } } diff --git a/services/context/access_log.go b/services/context/access_log.go index 0926748ac5..1985aae118 100644 --- a/services/context/access_log.go +++ b/services/context/access_log.go @@ -18,13 +18,14 @@ import ( "code.gitea.io/gitea/modules/web/middleware" ) -type routerLoggerOptions struct { - req *http.Request +type accessLoggerTmplData struct { Identity *string Start *time.Time - ResponseWriter http.ResponseWriter - Ctx map[string]any - RequestID *string + ResponseWriter struct { + Status, Size int + } + Ctx map[string]any + RequestID *string } const keyOfRequestIDInTemplate = ".RequestID" @@ -51,51 +52,65 @@ func parseRequestIDFromRequestHeader(req *http.Request) string { return requestID } +type accessLogRecorder struct { + logger log.BaseLogger + logTemplate *template.Template + needRequestID bool +} + +func (lr *accessLogRecorder) record(start time.Time, respWriter ResponseWriter, req *http.Request) { + var requestID string + if lr.needRequestID { + requestID = parseRequestIDFromRequestHeader(req) + } + + reqHost, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + reqHost = req.RemoteAddr + } + + identity := "-" + data := middleware.GetContextData(req.Context()) + if signedUser, ok := data[middleware.ContextDataKeySignedUser].(*user_model.User); ok { + identity = signedUser.Name + } + buf := bytes.NewBuffer([]byte{}) + tmplData := accessLoggerTmplData{ + Identity: &identity, + Start: &start, + Ctx: map[string]any{ + "RemoteAddr": req.RemoteAddr, + "RemoteHost": reqHost, + "Req": req, + }, + RequestID: &requestID, + } + tmplData.ResponseWriter.Status = respWriter.Status() + tmplData.ResponseWriter.Size = respWriter.WrittenSize() + err = lr.logTemplate.Execute(buf, tmplData) + if err != nil { + log.Error("Could not execute access logger template: %v", err.Error()) + } + + lr.logger.Log(1, log.INFO, "%s", buf.String()) +} + +func newAccessLogRecorder() *accessLogRecorder { + return &accessLogRecorder{ + logger: log.GetLogger("access"), + logTemplate: template.Must(template.New("log").Parse(setting.Log.AccessLogTemplate)), + needRequestID: len(setting.Log.RequestIDHeaders) > 0 && strings.Contains(setting.Log.AccessLogTemplate, keyOfRequestIDInTemplate), + } +} + // AccessLogger returns a middleware to log access logger func AccessLogger() func(http.Handler) http.Handler { - logger := log.GetLogger("access") - needRequestID := len(setting.Log.RequestIDHeaders) > 0 && strings.Contains(setting.Log.AccessLogTemplate, keyOfRequestIDInTemplate) - logTemplate, _ := template.New("log").Parse(setting.Log.AccessLogTemplate) + recorder := newAccessLogRecorder() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { start := time.Now() - - var requestID string - if needRequestID { - requestID = parseRequestIDFromRequestHeader(req) - } - - reqHost, _, err := net.SplitHostPort(req.RemoteAddr) - if err != nil { - reqHost = req.RemoteAddr - } - next.ServeHTTP(w, req) - rw := w.(ResponseWriter) - - identity := "-" - data := middleware.GetContextData(req.Context()) - if signedUser, ok := data[middleware.ContextDataKeySignedUser].(*user_model.User); ok { - identity = signedUser.Name - } - buf := bytes.NewBuffer([]byte{}) - err = logTemplate.Execute(buf, routerLoggerOptions{ - req: req, - Identity: &identity, - Start: &start, - ResponseWriter: rw, - Ctx: map[string]any{ - "RemoteAddr": req.RemoteAddr, - "RemoteHost": reqHost, - "Req": req, - }, - RequestID: &requestID, - }) - if err != nil { - log.Error("Could not execute access logger template: %v", err.Error()) - } - - logger.Info("%s", buf.String()) + recorder.record(start, w.(ResponseWriter), req) }) } } diff --git a/services/context/access_log_test.go b/services/context/access_log_test.go new file mode 100644 index 0000000000..9aab918ae6 --- /dev/null +++ b/services/context/access_log_test.go @@ -0,0 +1,75 @@ +// Copyright 2025 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package context + +import ( + "fmt" + "net/http" + "net/url" + "testing" + "time" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + + "github.com/stretchr/testify/assert" +) + +type testAccessLoggerMock struct { + logs []string +} + +func (t *testAccessLoggerMock) Log(skip int, level log.Level, format string, v ...any) { + t.logs = append(t.logs, fmt.Sprintf(format, v...)) +} + +func (t *testAccessLoggerMock) GetLevel() log.Level { + return log.INFO +} + +type testAccessLoggerResponseWriterMock struct{} + +func (t testAccessLoggerResponseWriterMock) Header() http.Header { + return nil +} + +func (t testAccessLoggerResponseWriterMock) Before(f func(ResponseWriter)) {} + +func (t testAccessLoggerResponseWriterMock) WriteHeader(statusCode int) {} + +func (t testAccessLoggerResponseWriterMock) Write(bytes []byte) (int, error) { + return 0, nil +} + +func (t testAccessLoggerResponseWriterMock) Flush() {} + +func (t testAccessLoggerResponseWriterMock) WrittenStatus() int { + return http.StatusOK +} + +func (t testAccessLoggerResponseWriterMock) Status() int { + return t.WrittenStatus() +} + +func (t testAccessLoggerResponseWriterMock) WrittenSize() int { + return 123123 +} + +func TestAccessLogger(t *testing.T) { + setting.Log.AccessLogTemplate = `{{.Ctx.RemoteHost}} - {{.Identity}} {{.Start.Format "[02/Jan/2006:15:04:05 -0700]" }} "{{.Ctx.Req.Method}} {{.Ctx.Req.URL.RequestURI}} {{.Ctx.Req.Proto}}" {{.ResponseWriter.Status}} {{.ResponseWriter.Size}} "{{.Ctx.Req.Referer}}" "{{.Ctx.Req.UserAgent}}"` + recorder := newAccessLogRecorder() + mockLogger := &testAccessLoggerMock{} + recorder.logger = mockLogger + req := &http.Request{ + RemoteAddr: "remote-addr", + Method: "GET", + Proto: "https", + URL: &url.URL{Path: "/path"}, + } + req.Header = http.Header{} + req.Header.Add("Referer", "referer") + req.Header.Add("User-Agent", "user-agent") + recorder.record(time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), &testAccessLoggerResponseWriterMock{}, req) + assert.Equal(t, []string{`remote-addr - - [02/Jan/2000:03:04:05 +0000] "GET /path https" 200 123123 "referer" "user-agent"`}, mockLogger.logs) +} diff --git a/services/context/response.go b/services/context/response.go index 2f271f211b..3e557a112e 100644 --- a/services/context/response.go +++ b/services/context/response.go @@ -15,27 +15,26 @@ type ResponseWriter interface { http.Flusher web_types.ResponseStatusProvider - Before(func(ResponseWriter)) - - Status() int // used by access logger template - Size() int // used by access logger template + Before(fn func(ResponseWriter)) + Status() int + WrittenSize() int } -var _ ResponseWriter = &Response{} +var _ ResponseWriter = (*Response)(nil) // Response represents a response type Response struct { http.ResponseWriter written int status int - befores []func(ResponseWriter) + beforeFuncs []func(ResponseWriter) beforeExecuted bool } // Write writes bytes to HTTP endpoint func (r *Response) Write(bs []byte) (int, error) { if !r.beforeExecuted { - for _, before := range r.befores { + for _, before := range r.beforeFuncs { before(r) } r.beforeExecuted = true @@ -51,18 +50,14 @@ func (r *Response) Write(bs []byte) (int, error) { return size, nil } -func (r *Response) Status() int { - return r.status -} - -func (r *Response) Size() int { +func (r *Response) WrittenSize() int { return r.written } // WriteHeader write status code func (r *Response) WriteHeader(statusCode int) { if !r.beforeExecuted { - for _, before := range r.befores { + for _, before := range r.beforeFuncs { before(r) } r.beforeExecuted = true @@ -80,6 +75,12 @@ func (r *Response) Flush() { } } +// Status returns status code written +// TODO: use WrittenStatus instead +func (r *Response) Status() int { + return r.status +} + // WrittenStatus returned status code written func (r *Response) WrittenStatus() int { return r.status @@ -87,17 +88,13 @@ func (r *Response) WrittenStatus() int { // Before allows for a function to be called before the ResponseWriter has been written to. This is // useful for setting headers or any other operations that must happen before a response has been written. -func (r *Response) Before(f func(ResponseWriter)) { - r.befores = append(r.befores, f) +func (r *Response) Before(fn func(ResponseWriter)) { + r.beforeFuncs = append(r.beforeFuncs, fn) } func WrapResponseWriter(resp http.ResponseWriter) *Response { if v, ok := resp.(*Response); ok { return v } - return &Response{ - ResponseWriter: resp, - status: 0, - befores: make([]func(ResponseWriter), 0), - } + return &Response{ResponseWriter: resp} }