diff --git a/models/activities/notification_list.go b/models/activities/notification_list.go index 957f9456e7..5858933391 100644 --- a/models/activities/notification_list.go +++ b/models/activities/notification_list.go @@ -14,6 +14,7 @@ import ( user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/container" "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/util" "xorm.io/builder" ) @@ -470,3 +471,31 @@ func (nl NotificationList) LoadComments(ctx context.Context) ([]int, error) { } return failures, nil } + +// LoadIssuePullRequests loads all issues' pull requests if possible +func (nl NotificationList) LoadIssuePullRequests(ctx context.Context) error { + issues := make(map[int64]*issues_model.Issue, len(nl)) + for _, notification := range nl { + if notification.Issue != nil && notification.Issue.IsPull && notification.Issue.PullRequest == nil { + issues[notification.Issue.ID] = notification.Issue + } + } + + if len(issues) == 0 { + return nil + } + + pulls, err := issues_model.GetPullRequestByIssueIDs(ctx, util.KeysOfMap(issues)) + if err != nil { + return err + } + + for _, pull := range pulls { + if issue := issues[pull.IssueID]; issue != nil { + issue.PullRequest = pull + issue.PullRequest.Issue = issue + } + } + + return nil +} diff --git a/models/issues/issue.go b/models/issues/issue.go index 563a780dcb..87c1c86eb1 100644 --- a/models/issues/issue.go +++ b/models/issues/issue.go @@ -193,20 +193,6 @@ func (issue *Issue) IsTimetrackerEnabled(ctx context.Context) bool { return issue.Repo.IsTimetrackerEnabled(ctx) } -// GetPullRequest returns the issue pull request -func (issue *Issue) GetPullRequest(ctx context.Context) (pr *PullRequest, err error) { - if !issue.IsPull { - return nil, fmt.Errorf("Issue is not a pull request") - } - - pr, err = GetPullRequestByIssueID(ctx, issue.ID) - if err != nil { - return nil, err - } - pr.Issue = issue - return pr, err -} - // LoadPoster loads poster func (issue *Issue) LoadPoster(ctx context.Context) (err error) { if issue.Poster == nil && issue.PosterID != 0 { diff --git a/models/issues/issue_list.go b/models/issues/issue_list.go index 41a90d133d..218891ad35 100644 --- a/models/issues/issue_list.go +++ b/models/issues/issue_list.go @@ -370,6 +370,9 @@ func (issues IssueList) LoadPullRequests(ctx context.Context) error { for _, issue := range issues { issue.PullRequest = pullRequestMaps[issue.ID] + if issue.PullRequest != nil { + issue.PullRequest.Issue = issue + } } return nil } diff --git a/models/issues/pull_list.go b/models/issues/pull_list.go index c209386e2e..2ee69cd323 100644 --- a/models/issues/pull_list.go +++ b/models/issues/pull_list.go @@ -212,3 +212,12 @@ func HasMergedPullRequestInRepo(ctx context.Context, repoID, posterID int64) (bo Limit(1). Get(new(Issue)) } + +// GetPullRequestByIssueIDs returns all pull requests by issue ids +func GetPullRequestByIssueIDs(ctx context.Context, issueIDs []int64) (PullRequestList, error) { + prs := make([]*PullRequest, 0, len(issueIDs)) + return prs, db.GetEngine(ctx). + Where("issue_id > 0"). + In("issue_id", issueIDs). + Find(&prs) +} diff --git a/models/issues/review.go b/models/issues/review.go index fc110630e0..70aba0f94d 100644 --- a/models/issues/review.go +++ b/models/issues/review.go @@ -239,11 +239,11 @@ type CreateReviewOptions struct { // IsOfficialReviewer check if at least one of the provided reviewers can make official reviews in issue (counts towards required approvals) func IsOfficialReviewer(ctx context.Context, issue *Issue, reviewer *user_model.User) (bool, error) { - pr, err := GetPullRequestByIssueID(ctx, issue.ID) - if err != nil { + if err := issue.LoadPullRequest(ctx); err != nil { return false, err } + pr := issue.PullRequest rule, err := git_model.GetFirstMatchProtectedBranchRule(ctx, pr.BaseRepoID, pr.BaseBranch) if err != nil { return false, err @@ -271,11 +271,10 @@ func IsOfficialReviewer(ctx context.Context, issue *Issue, reviewer *user_model. // IsOfficialReviewerTeam check if reviewer in this team can make official reviews in issue (counts towards required approvals) func IsOfficialReviewerTeam(ctx context.Context, issue *Issue, team *organization.Team) (bool, error) { - pr, err := GetPullRequestByIssueID(ctx, issue.ID) - if err != nil { + if err := issue.LoadPullRequest(ctx); err != nil { return false, err } - pb, err := git_model.GetFirstMatchProtectedBranchRule(ctx, pr.BaseRepoID, pr.BaseBranch) + pb, err := git_model.GetFirstMatchProtectedBranchRule(ctx, issue.PullRequest.BaseRepoID, issue.PullRequest.BaseBranch) if err != nil { return false, err } diff --git a/modules/util/slice.go b/modules/util/slice.go index f00e84bf06..9c878c24be 100644 --- a/modules/util/slice.go +++ b/modules/util/slice.go @@ -54,7 +54,7 @@ func Sorted[S ~[]E, E cmp.Ordered](values S) S { return values } -// TODO: Replace with "maps.Values" once available +// TODO: Replace with "maps.Values" once available, current it only in golang.org/x/exp/maps but not in standard library func ValuesOfMap[K comparable, V any](m map[K]V) []V { values := make([]V, 0, len(m)) for _, v := range m { @@ -62,3 +62,12 @@ func ValuesOfMap[K comparable, V any](m map[K]V) []V { } return values } + +// TODO: Replace with "maps.Keys" once available, current it only in golang.org/x/exp/maps but not in standard library +func KeysOfMap[K comparable, V any](m map[K]V) []K { + keys := make([]K, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} diff --git a/routers/api/v1/repo/issue.go b/routers/api/v1/repo/issue.go index 61a318baab..6934b34b24 100644 --- a/routers/api/v1/repo/issue.go +++ b/routers/api/v1/repo/issue.go @@ -872,10 +872,11 @@ func EditIssue(ctx *context.APIContext) { } if form.State != nil { if issue.IsPull { - if pr, err := issue.GetPullRequest(ctx); err != nil { + if err := issue.LoadPullRequest(ctx); err != nil { ctx.Error(http.StatusInternalServerError, "GetPullRequest", err) return - } else if pr.HasMerged { + } + if issue.PullRequest.HasMerged { ctx.Error(http.StatusPreconditionFailed, "MergedPRState", "cannot change state of this pull request, it was already merged") return } diff --git a/routers/api/v1/repo/issue_pin.go b/routers/api/v1/repo/issue_pin.go index ff1135862b..8fcf670fd0 100644 --- a/routers/api/v1/repo/issue_pin.go +++ b/routers/api/v1/repo/issue_pin.go @@ -240,18 +240,12 @@ func ListPinnedPullRequests(ctx *context.APIContext) { } apiPrs := make([]*api.PullRequest, len(issues)) + if err := issues.LoadPullRequests(ctx); err != nil { + ctx.Error(http.StatusInternalServerError, "LoadPullRequests", err) + return + } for i, currentIssue := range issues { - pr, err := currentIssue.GetPullRequest(ctx) - if err != nil { - ctx.Error(http.StatusInternalServerError, "GetPullRequest", err) - return - } - - if err = pr.LoadIssue(ctx); err != nil { - ctx.Error(http.StatusInternalServerError, "LoadIssue", err) - return - } - + pr := currentIssue.PullRequest if err = pr.LoadAttributes(ctx); err != nil { ctx.Error(http.StatusInternalServerError, "LoadAttributes", err) return diff --git a/routers/web/user/notification.go b/routers/web/user/notification.go index 438462371b..28f9846d6b 100644 --- a/routers/web/user/notification.go +++ b/routers/web/user/notification.go @@ -144,6 +144,12 @@ func getNotifications(ctx *context.Context) { ctx.ServerError("LoadIssues", err) return } + + if err = notifications.LoadIssuePullRequests(ctx); err != nil { + ctx.ServerError("LoadIssuePullRequests", err) + return + } + notifications = notifications.Without(failures) failCount += len(failures) diff --git a/services/convert/notification.go b/services/convert/notification.go index 0b97530d8b..41063cf399 100644 --- a/services/convert/notification.go +++ b/services/convert/notification.go @@ -61,8 +61,9 @@ func ToNotificationThread(ctx context.Context, n *activities_model.Notification) result.Subject.LatestCommentHTMLURL = comment.HTMLURL(ctx) } - pr, _ := n.Issue.GetPullRequest(ctx) - if pr != nil && pr.HasMerged { + if err := n.Issue.LoadPullRequest(ctx); err == nil && + n.Issue.PullRequest != nil && + n.Issue.PullRequest.HasMerged { result.Subject.State = "merged" } } diff --git a/services/pull/review.go b/services/pull/review.go index 90d07c8358..8900ae2ab1 100644 --- a/services/pull/review.go +++ b/services/pull/review.go @@ -268,11 +268,11 @@ func createCodeComment(ctx context.Context, doer *user_model.User, repo *repo_mo // SubmitReview creates a review out of the existing pending review or creates a new one if no pending review exist func SubmitReview(ctx context.Context, doer *user_model.User, gitRepo *git.Repository, issue *issues_model.Issue, reviewType issues_model.ReviewType, content, commitID string, attachmentUUIDs []string) (*issues_model.Review, *issues_model.Comment, error) { - pr, err := issue.GetPullRequest(ctx) - if err != nil { + if err := issue.LoadPullRequest(ctx); err != nil { return nil, nil, err } + pr := issue.PullRequest var stale bool if reviewType != issues_model.ReviewTypeApprove && reviewType != issues_model.ReviewTypeReject { stale = false diff --git a/templates/shared/issueicon.tmpl b/templates/shared/issueicon.tmpl index 089e80bd8b..a62714e988 100644 --- a/templates/shared/issueicon.tmpl +++ b/templates/shared/issueicon.tmpl @@ -1,15 +1,15 @@ {{if .IsPull}} - {{if and .PullRequest .PullRequest.HasMerged}} - {{svg "octicon-git-merge" 16 "text purple"}} - {{else if and (.GetPullRequest ctx) (.GetPullRequest ctx).HasMerged}} - {{svg "octicon-git-merge" 16 "text purple"}} + {{if not .PullRequest}} + No PullRequest {{else}} {{if .IsClosed}} - {{svg "octicon-git-pull-request" 16 "text red"}} + {{if .PullRequest.HasMerged}} + {{svg "octicon-git-merge" 16 "text purple"}} + {{else}} + {{svg "octicon-git-pull-request" 16 "text red"}} + {{end}} {{else}} - {{if and .PullRequest (.PullRequest.IsWorkInProgress ctx)}} - {{svg "octicon-git-pull-request-draft" 16 "text grey"}} - {{else if and (.GetPullRequest ctx) ((.GetPullRequest ctx).IsWorkInProgress ctx)}} + {{if .PullRequest.IsWorkInProgress ctx}} {{svg "octicon-git-pull-request-draft" 16 "text grey"}} {{else}} {{svg "octicon-git-pull-request" 16 "text green"}} diff --git a/tests/integration/pull_merge_test.go b/tests/integration/pull_merge_test.go index a04b4c98cd..daf411f452 100644 --- a/tests/integration/pull_merge_test.go +++ b/tests/integration/pull_merge_test.go @@ -516,8 +516,8 @@ func TestConflictChecking(t *testing.T) { assert.NoError(t, err) issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{Title: "PR with conflict!"}) - conflictingPR, err := issues_model.GetPullRequestByIssueID(db.DefaultContext, issue.ID) - assert.NoError(t, err) + assert.NoError(t, issue.LoadPullRequest(db.DefaultContext)) + conflictingPR := issue.PullRequest // Ensure conflictedFiles is populated. assert.Len(t, conflictingPR.ConflictedFiles, 1) diff --git a/tests/integration/pull_update_test.go b/tests/integration/pull_update_test.go index 078253ffb0..5ae241f3af 100644 --- a/tests/integration/pull_update_test.go +++ b/tests/integration/pull_update_test.go @@ -177,8 +177,7 @@ func createOutdatedPR(t *testing.T, actor, forkOrg *user_model.User) *issues_mod assert.NoError(t, err) issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{Title: "Test Pull -to-update-"}) - pr, err := issues_model.GetPullRequestByIssueID(db.DefaultContext, issue.ID) - assert.NoError(t, err) + assert.NoError(t, issue.LoadPullRequest(db.DefaultContext)) - return pr + return issue.PullRequest }