From ca414a7ccf5e26272662e360c44ac50221a0f2d4 Mon Sep 17 00:00:00 2001 From: Thomas Desveaux Date: Tue, 4 Jun 2024 13:56:59 +0200 Subject: [PATCH] Fix NuGet Package API for $filter with Id equality (#31188) (#31242) Backport #31188 Fixes issue when running `choco info pkgname` where `pkgname` is also a substring of another package Id. Relates to #31168 --- This might fix the issue linked, but I'd like to test it with more choco commands before closing the issue in case I find other problems if that's ok. I'm pretty inexperienced with Go, so feel free to nitpick things. Not sure I handled [this](https://github.com/tdesveaux/gitea/blob/70f87e11b5caf1ee441ae71c7eba1831f45897d4/routers/api/packages/nuget/nuget.go#L135-L137) in the best way, so looking for feedback on if I should fix the underlying issue (`nil` might be a better default for `Value`?). Co-authored-by: KN4CK3R --- routers/api/packages/nuget/nuget.go | 48 +++++---- tests/integration/api_packages_nuget_test.go | 102 ++++++++++++++++--- 2 files changed, 115 insertions(+), 35 deletions(-) diff --git a/routers/api/packages/nuget/nuget.go b/routers/api/packages/nuget/nuget.go index 26b0ae226e..3633d0d007 100644 --- a/routers/api/packages/nuget/nuget.go +++ b/routers/api/packages/nuget/nuget.go @@ -96,20 +96,34 @@ func FeedCapabilityResource(ctx *context.Context) { xmlResponse(ctx, http.StatusOK, Metadata) } -var searchTermExtract = regexp.MustCompile(`'([^']+)'`) +var ( + searchTermExtract = regexp.MustCompile(`'([^']+)'`) + searchTermExact = regexp.MustCompile(`\s+eq\s+'`) +) -func getSearchTerm(ctx *context.Context) string { +func getSearchTerm(ctx *context.Context) packages_model.SearchValue { searchTerm := strings.Trim(ctx.FormTrim("searchTerm"), "'") - if searchTerm == "" { - // $filter contains a query like: - // (((Id ne null) and substringof('microsoft',tolower(Id))) - // We don't support these queries, just extract the search term. - match := searchTermExtract.FindStringSubmatch(ctx.FormTrim("$filter")) - if len(match) == 2 { - searchTerm = strings.TrimSpace(match[1]) + if searchTerm != "" { + return packages_model.SearchValue{ + Value: searchTerm, + ExactMatch: false, } } - return searchTerm + + // $filter contains a query like: + // (((Id ne null) and substringof('microsoft',tolower(Id))) + // https://www.odata.org/documentation/odata-version-2-0/uri-conventions/ section 4.5 + // We don't support these queries, just extract the search term. + filter := ctx.FormTrim("$filter") + match := searchTermExtract.FindStringSubmatch(filter) + if len(match) == 2 { + return packages_model.SearchValue{ + Value: strings.TrimSpace(match[1]), + ExactMatch: searchTermExact.MatchString(filter), + } + } + + return packages_model.SearchValue{} } // https://github.com/NuGet/NuGet.Client/blob/dev/src/NuGet.Core/NuGet.Protocol/LegacyFeed/V2FeedQueryBuilder.cs @@ -118,11 +132,9 @@ func SearchServiceV2(ctx *context.Context) { paginator := db.NewAbsoluteListOptions(skip, take) pvs, total, err := packages_model.SearchLatestVersions(ctx, &packages_model.PackageSearchOptions{ - OwnerID: ctx.Package.Owner.ID, - Type: packages_model.TypeNuGet, - Name: packages_model.SearchValue{ - Value: getSearchTerm(ctx), - }, + OwnerID: ctx.Package.Owner.ID, + Type: packages_model.TypeNuGet, + Name: getSearchTerm(ctx), IsInternal: optional.Some(false), Paginator: paginator, }) @@ -169,10 +181,8 @@ func SearchServiceV2(ctx *context.Context) { // http://docs.oasis-open.org/odata/odata/v4.0/errata03/os/complete/part2-url-conventions/odata-v4.0-errata03-os-part2-url-conventions-complete.html#_Toc453752351 func SearchServiceV2Count(ctx *context.Context) { count, err := nuget_model.CountPackages(ctx, &packages_model.PackageSearchOptions{ - OwnerID: ctx.Package.Owner.ID, - Name: packages_model.SearchValue{ - Value: getSearchTerm(ctx), - }, + OwnerID: ctx.Package.Owner.ID, + Name: getSearchTerm(ctx), IsInternal: optional.Some(false), }) if err != nil { diff --git a/tests/integration/api_packages_nuget_test.go b/tests/integration/api_packages_nuget_test.go index 83947ff967..630b4de3f9 100644 --- a/tests/integration/api_packages_nuget_test.go +++ b/tests/integration/api_packages_nuget_test.go @@ -429,22 +429,33 @@ AAAjQmxvYgAAAGm7ENm9SGxMtAFVvPUsPJTF6PbtAAAAAFcVogEJAAAAAQAAAA==`) t.Run("SearchService", func(t *testing.T) { cases := []struct { - Query string - Skip int - Take int - ExpectedTotal int64 - ExpectedResults int + Query string + Skip int + Take int + ExpectedTotal int64 + ExpectedResults int + ExpectedExactMatch bool }{ - {"", 0, 0, 1, 1}, - {"", 0, 10, 1, 1}, - {"gitea", 0, 10, 0, 0}, - {"test", 0, 10, 1, 1}, - {"test", 1, 10, 1, 0}, + {"", 0, 0, 4, 4, false}, + {"", 0, 10, 4, 4, false}, + {"gitea", 0, 10, 0, 0, false}, + {"test", 0, 10, 1, 1, false}, + {"test", 1, 10, 1, 0, false}, + {"almost.similar", 0, 0, 3, 3, true}, } - req := NewRequestWithBody(t, "PUT", url, createPackage(packageName, "1.0.99")). - AddBasicAuth(user.Name) - MakeRequest(t, req, http.StatusCreated) + fakePackages := []string{ + packageName, + "almost.similar.dependency", + "almost.similar", + "almost.similar.dependant", + } + + for _, fakePackageName := range fakePackages { + req := NewRequestWithBody(t, "PUT", url, createPackage(fakePackageName, "1.0.99")). + AddBasicAuth(user.Name) + MakeRequest(t, req, http.StatusCreated) + } t.Run("v2", func(t *testing.T) { t.Run("Search()", func(t *testing.T) { @@ -491,6 +502,63 @@ AAAjQmxvYgAAAGm7ENm9SGxMtAFVvPUsPJTF6PbtAAAAAFcVogEJAAAAAQAAAA==`) } }) + t.Run("Packages()", func(t *testing.T) { + defer tests.PrintCurrentTest(t)() + + t.Run("substringof", func(t *testing.T) { + defer tests.PrintCurrentTest(t)() + + for i, c := range cases { + req := NewRequest(t, "GET", fmt.Sprintf("%s/Packages()?$filter=substringof('%s',tolower(Id))&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)). + AddBasicAuth(user.Name) + resp := MakeRequest(t, req, http.StatusOK) + + var result FeedResponse + decodeXML(t, resp, &result) + + assert.Equal(t, c.ExpectedTotal, result.Count, "case %d: unexpected total hits", i) + assert.Len(t, result.Entries, c.ExpectedResults, "case %d: unexpected result count", i) + + req = NewRequest(t, "GET", fmt.Sprintf("%s/Packages()/$count?$filter=substringof('%s',tolower(Id))&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)). + AddBasicAuth(user.Name) + resp = MakeRequest(t, req, http.StatusOK) + + assert.Equal(t, strconv.FormatInt(c.ExpectedTotal, 10), resp.Body.String(), "case %d: unexpected total hits", i) + } + }) + + t.Run("IdEq", func(t *testing.T) { + defer tests.PrintCurrentTest(t)() + + for i, c := range cases { + if c.Query == "" { + // Ignore the `tolower(Id) eq ''` as it's unlikely to happen + continue + } + req := NewRequest(t, "GET", fmt.Sprintf("%s/Packages()?$filter=(tolower(Id) eq '%s')&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)). + AddBasicAuth(user.Name) + resp := MakeRequest(t, req, http.StatusOK) + + var result FeedResponse + decodeXML(t, resp, &result) + + expectedCount := 0 + if c.ExpectedExactMatch { + expectedCount = 1 + } + + assert.Equal(t, int64(expectedCount), result.Count, "case %d: unexpected total hits", i) + assert.Len(t, result.Entries, expectedCount, "case %d: unexpected result count", i) + + req = NewRequest(t, "GET", fmt.Sprintf("%s/Packages()/$count?$filter=(tolower(Id) eq '%s')&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)). + AddBasicAuth(user.Name) + resp = MakeRequest(t, req, http.StatusOK) + + assert.Equal(t, strconv.FormatInt(int64(expectedCount), 10), resp.Body.String(), "case %d: unexpected total hits", i) + } + }) + }) + t.Run("Next", func(t *testing.T) { req := NewRequest(t, "GET", fmt.Sprintf("%s/Search()?searchTerm='test'&$skip=0&$top=1", url)). AddBasicAuth(user.Name) @@ -548,9 +616,11 @@ AAAjQmxvYgAAAGm7ENm9SGxMtAFVvPUsPJTF6PbtAAAAAFcVogEJAAAAAQAAAA==`) }) }) - req = NewRequest(t, "DELETE", fmt.Sprintf("%s/%s/%s", url, packageName, "1.0.99")). - AddBasicAuth(user.Name) - MakeRequest(t, req, http.StatusNoContent) + for _, fakePackageName := range fakePackages { + req := NewRequest(t, "DELETE", fmt.Sprintf("%s/%s/%s", url, fakePackageName, "1.0.99")). + AddBasicAuth(user.Name) + MakeRequest(t, req, http.StatusNoContent) + } }) t.Run("RegistrationService", func(t *testing.T) {