From a08856606e6dd9bfc8d6a7d07b4d2d27cf35afed Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 20 Nov 2021 00:31:29 +0800 Subject: [PATCH] Return 400 but not 500 when request archive with wrong format (#17691) (#17700) * Return 400 but not 500 when request archive with wrong format (#17691) * Remove bundle because it's not in this version --- integrations/api_repo_archive_test.go | 44 +++++++++++++++++++++++++++ routers/web/repo/repo.go | 6 +++- services/archiver/archiver.go | 18 ++++++++++- services/archiver/archiver_test.go | 13 ++++---- 4 files changed, 72 insertions(+), 9 deletions(-) create mode 100644 integrations/api_repo_archive_test.go diff --git a/integrations/api_repo_archive_test.go b/integrations/api_repo_archive_test.go new file mode 100644 index 0000000000..4828c4aae1 --- /dev/null +++ b/integrations/api_repo_archive_test.go @@ -0,0 +1,44 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "fmt" + "io" + "net/http" + "net/url" + "testing" + + "code.gitea.io/gitea/models" + + "github.com/stretchr/testify/assert" +) + +func TestAPIDownloadArchive(t *testing.T) { + defer prepareTestEnv(t)() + + repo := models.AssertExistsAndLoadBean(t, &models.Repository{ID: 1}).(*models.Repository) + user2 := models.AssertExistsAndLoadBean(t, &models.User{ID: 2}).(*models.User) + session := loginUser(t, user2.LowerName) + token := getTokenForLoggedInUser(t, session) + + link, _ := url.Parse(fmt.Sprintf("/api/v1/repos/%s/%s/archive/master.zip", user2.Name, repo.Name)) + link.RawQuery = url.Values{"token": {token}}.Encode() + resp := MakeRequest(t, NewRequest(t, "GET", link.String()), http.StatusOK) + bs, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.EqualValues(t, 320, len(bs)) + + link, _ = url.Parse(fmt.Sprintf("/api/v1/repos/%s/%s/archive/master.tar.gz", user2.Name, repo.Name)) + link.RawQuery = url.Values{"token": {token}}.Encode() + resp = MakeRequest(t, NewRequest(t, "GET", link.String()), http.StatusOK) + bs, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.EqualValues(t, 266, len(bs)) + + link, _ = url.Parse(fmt.Sprintf("/api/v1/repos/%s/%s/archive/master", user2.Name, repo.Name)) + link.RawQuery = url.Values{"token": {token}}.Encode() + MakeRequest(t, NewRequest(t, "GET", link.String()), http.StatusBadRequest) +} diff --git a/routers/web/repo/repo.go b/routers/web/repo/repo.go index 919fd4620d..8d5546f8cd 100644 --- a/routers/web/repo/repo.go +++ b/routers/web/repo/repo.go @@ -371,7 +371,11 @@ func Download(ctx *context.Context) { uri := ctx.Params("*") aReq, err := archiver_service.NewRequest(ctx.Repo.Repository.ID, ctx.Repo.GitRepo, uri) if err != nil { - ctx.ServerError("archiver_service.NewRequest", err) + if errors.Is(err, archiver_service.ErrUnknownArchiveFormat{}) { + ctx.Error(http.StatusBadRequest, err.Error()) + } else { + ctx.ServerError("archiver_service.NewRequest", err) + } return } if aReq == nil { diff --git a/services/archiver/archiver.go b/services/archiver/archiver.go index d64ec2bdea..cf9c49764b 100644 --- a/services/archiver/archiver.go +++ b/services/archiver/archiver.go @@ -38,6 +38,22 @@ type ArchiveRequest struct { // the way to 64. var shaRegex = regexp.MustCompile(`^[0-9a-f]{4,64}$`) +// ErrUnknownArchiveFormat request archive format is not supported +type ErrUnknownArchiveFormat struct { + RequestFormat string +} + +// Error implements error +func (err ErrUnknownArchiveFormat) Error() string { + return fmt.Sprintf("unknown format: %s", err.RequestFormat) +} + +// Is implements error +func (ErrUnknownArchiveFormat) Is(err error) bool { + _, ok := err.(ErrUnknownArchiveFormat) + return ok +} + // NewRequest creates an archival request, based on the URI. The // resulting ArchiveRequest is suitable for being passed to ArchiveRepository() // if it's determined that the request still needs to be satisfied. @@ -55,7 +71,7 @@ func NewRequest(repoID int64, repo *git.Repository, uri string) (*ArchiveRequest ext = ".tar.gz" r.Type = git.TARGZ default: - return nil, fmt.Errorf("Unknown format: %s", uri) + return nil, ErrUnknownArchiveFormat{RequestFormat: uri} } r.refName = strings.TrimSuffix(uri, ext) diff --git a/services/archiver/archiver_test.go b/services/archiver/archiver_test.go index 3f3f369987..f9aff7aec3 100644 --- a/services/archiver/archiver_test.go +++ b/services/archiver/archiver_test.go @@ -5,6 +5,7 @@ package archiver import ( + "errors" "path/filepath" "testing" "time" @@ -19,10 +20,6 @@ func TestMain(m *testing.M) { models.MainTest(m, filepath.Join("..", "..")) } -func waitForCount(t *testing.T, num int) { - -} - func TestArchive_Basic(t *testing.T) { assert.NoError(t, models.PrepareTestDatabase()) @@ -83,11 +80,8 @@ func TestArchive_Basic(t *testing.T) { inFlight[2] = secondReq ArchiveRepository(zipReq) - waitForCount(t, 1) ArchiveRepository(tgzReq) - waitForCount(t, 2) ArchiveRepository(secondReq) - waitForCount(t, 3) // Make sure sending an unprocessed request through doesn't affect the queue // count. @@ -132,3 +126,8 @@ func TestArchive_Basic(t *testing.T) { assert.NotEqual(t, zipReq.GetArchiveName(), tgzReq.GetArchiveName()) assert.NotEqual(t, zipReq.GetArchiveName(), secondReq.GetArchiveName()) } + +func TestErrUnknownArchiveFormat(t *testing.T) { + var err = ErrUnknownArchiveFormat{RequestFormat: "master"} + assert.True(t, errors.Is(err, ErrUnknownArchiveFormat{})) +}