diff --git a/src/website/routes.go b/src/website/routes.go index d4d9e60..01cfae7 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -31,10 +31,9 @@ func NewWebsiteRoutes(conn *pgxpool.Pool, perfCollector *perf.PerfCollector) htt logPerf := TrackRequestPerf(c, perfCollector) defer logPerf() - res = h(c) + defer LogContextErrors(c, &res) - LogContextErrors(c, res) - return + return h(c) } }, } @@ -47,15 +46,14 @@ func NewWebsiteRoutes(conn *pgxpool.Pool, perfCollector *perf.PerfCollector) htt logPerf := TrackRequestPerf(c, perfCollector) defer logPerf() + defer LogContextErrors(c, &res) + ok, errRes := LoadCommonWebsiteData(c) if !ok { return errRes } - res = h(c) - - LogContextErrors(c, res) - return + return h(c) } } @@ -67,20 +65,18 @@ func NewWebsiteRoutes(conn *pgxpool.Pool, perfCollector *perf.PerfCollector) htt logPerf := TrackRequestPerf(c, perfCollector) defer logPerf() + defer LogContextErrors(c, &res) + ok, errRes := LoadCommonWebsiteData(c) if !ok { return errRes } if !c.CurrentProject.IsHMN() { - res = c.Redirect(hmnurl.Url(c.URL().String(), nil), http.StatusMovedPermanently) - return + return c.Redirect(hmnurl.Url(c.URL().String(), nil), http.StatusMovedPermanently) } - res = h(c) - - LogContextErrors(c, res) - return + return h(c) } } @@ -300,7 +296,7 @@ func TrackRequestPerf(c *RequestContext, perfCollector *perf.PerfCollector) (aft } } -func LogContextErrors(c *RequestContext, res ResponseData) { +func LogContextErrors(c *RequestContext, res *ResponseData) { for _, err := range res.Errors { c.Logger.Error().Err(err).Msg("error occurred during request") } diff --git a/src/website/routes_test.go b/src/website/routes_test.go new file mode 100644 index 0000000..86ea652 --- /dev/null +++ b/src/website/routes_test.go @@ -0,0 +1,54 @@ +package website + +import ( + "bytes" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" +) + +func TestLogContextErrors(t *testing.T) { + err1 := errors.New("test error 1") + err2 := errors.New("test error 2") + + var buf bytes.Buffer + logger := zerolog.New(&buf) + logger.Print("sanity check") + + assert.Contains(t, buf.String(), "sanity check") + + router := &Router{} + routes := RouteBuilder{ + Router: router, + Middleware: func(h Handler) Handler { + return func(c *RequestContext) (res ResponseData) { + c.Logger = &logger + defer LogContextErrors(c, &res) + return h(c) + } + }, + } + + routes.GET("^/test$", func(c *RequestContext) ResponseData { + return ErrorResponse(http.StatusInternalServerError, err1, err2) + }) + + srv := httptest.NewServer(router) + defer srv.Close() + + res, err := http.Get(srv.URL + "/test") + if assert.Nil(t, err) { + defer res.Body.Close() + + t.Logf("Log contents: %s", buf.String()) + + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + + assert.Contains(t, buf.String(), err1.Error()) + assert.Contains(t, buf.String(), err2.Error()) + } +}