diff --git a/models/activities/notification.go b/models/activities/notification.go index 7c794564b6..29b02b5fdf 100644 --- a/models/activities/notification.go +++ b/models/activities/notification.go @@ -20,6 +20,7 @@ import ( "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/modules/util" "xorm.io/builder" "xorm.io/xorm" @@ -821,3 +822,31 @@ func UpdateNotificationStatuses(ctx context.Context, user *user_model.User, curr Update(n) return err } + +// LoadIssuePullRequests loads all issues' pull requests if possible +func (nl NotificationList) LoadIssuePullRequests(ctx context.Context) error { + issues := make(map[int64]*issues_model.Issue, len(nl)) + for _, notification := range nl { + if notification.Issue != nil && notification.Issue.IsPull && notification.Issue.PullRequest == nil { + issues[notification.Issue.ID] = notification.Issue + } + } + + if len(issues) == 0 { + return nil + } + + pulls, err := issues_model.GetPullRequestByIssueIDs(ctx, util.KeysOfMap(issues)) + if err != nil { + return err + } + + for _, pull := range pulls { + if issue := issues[pull.IssueID]; issue != nil { + issue.PullRequest = pull + issue.PullRequest.Issue = issue + } + } + + return nil +} diff --git a/models/issues/issue.go b/models/issues/issue.go index 341ec8547a..9755ece702 100644 --- a/models/issues/issue.go +++ b/models/issues/issue.go @@ -200,20 +200,6 @@ func (issue *Issue) IsTimetrackerEnabled(ctx context.Context) bool { return issue.Repo.IsTimetrackerEnabled(ctx) } -// GetPullRequest returns the issue pull request -func (issue *Issue) GetPullRequest() (pr *PullRequest, err error) { - if !issue.IsPull { - return nil, fmt.Errorf("Issue is not a pull request") - } - - pr, err = GetPullRequestByIssueID(db.DefaultContext, issue.ID) - if err != nil { - return nil, err - } - pr.Issue = issue - return pr, err -} - // LoadPoster loads poster func (issue *Issue) LoadPoster(ctx context.Context) (err error) { if issue.Poster == nil && issue.PosterID != 0 { diff --git a/models/issues/issue_list.go b/models/issues/issue_list.go index a932ac2554..f07dae89fd 100644 --- a/models/issues/issue_list.go +++ b/models/issues/issue_list.go @@ -370,6 +370,9 @@ func (issues IssueList) LoadPullRequests(ctx context.Context) error { for _, issue := range issues { issue.PullRequest = pullRequestMaps[issue.ID] + if issue.PullRequest != nil { + issue.PullRequest.Issue = issue + } } return nil } diff --git a/models/issues/pull_list.go b/models/issues/pull_list.go index c209386e2e..2ee69cd323 100644 --- a/models/issues/pull_list.go +++ b/models/issues/pull_list.go @@ -212,3 +212,12 @@ func HasMergedPullRequestInRepo(ctx context.Context, repoID, posterID int64) (bo Limit(1). Get(new(Issue)) } + +// GetPullRequestByIssueIDs returns all pull requests by issue ids +func GetPullRequestByIssueIDs(ctx context.Context, issueIDs []int64) (PullRequestList, error) { + prs := make([]*PullRequest, 0, len(issueIDs)) + return prs, db.GetEngine(ctx). + Where("issue_id > 0"). + In("issue_id", issueIDs). + Find(&prs) +} diff --git a/models/issues/review.go b/models/issues/review.go index 18296c0dda..718708ec59 100644 --- a/models/issues/review.go +++ b/models/issues/review.go @@ -240,11 +240,11 @@ type CreateReviewOptions struct { // IsOfficialReviewer check if at least one of the provided reviewers can make official reviews in issue (counts towards required approvals) func IsOfficialReviewer(ctx context.Context, issue *Issue, reviewer *user_model.User) (bool, error) { - pr, err := GetPullRequestByIssueID(ctx, issue.ID) - if err != nil { + if err := issue.LoadPullRequest(ctx); err != nil { return false, err } + pr := issue.PullRequest rule, err := git_model.GetFirstMatchProtectedBranchRule(ctx, pr.BaseRepoID, pr.BaseBranch) if err != nil { return false, err @@ -272,11 +272,10 @@ func IsOfficialReviewer(ctx context.Context, issue *Issue, reviewer *user_model. // IsOfficialReviewerTeam check if reviewer in this team can make official reviews in issue (counts towards required approvals) func IsOfficialReviewerTeam(ctx context.Context, issue *Issue, team *organization.Team) (bool, error) { - pr, err := GetPullRequestByIssueID(ctx, issue.ID) - if err != nil { + if err := issue.LoadPullRequest(ctx); err != nil { return false, err } - pb, err := git_model.GetFirstMatchProtectedBranchRule(ctx, pr.BaseRepoID, pr.BaseBranch) + pb, err := git_model.GetFirstMatchProtectedBranchRule(ctx, issue.PullRequest.BaseRepoID, issue.PullRequest.BaseBranch) if err != nil { return false, err } diff --git a/modules/util/slice.go b/modules/util/slice.go index 6d63ab4a77..15ffd9deeb 100644 --- a/modules/util/slice.go +++ b/modules/util/slice.go @@ -45,3 +45,12 @@ func SliceSortedEqual[T comparable](s1, s2 []T) bool { func SliceRemoveAll[T comparable](slice []T, target T) []T { return slices.DeleteFunc(slice, func(t T) bool { return t == target }) } + +// TODO: Replace with "maps.Keys" once available, current it only in golang.org/x/exp/maps but not in standard library +func KeysOfMap[K comparable, V any](m map[K]V) []K { + keys := make([]K, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} diff --git a/routers/api/v1/repo/issue.go b/routers/api/v1/repo/issue.go index fe9a491603..8c502467fd 100644 --- a/routers/api/v1/repo/issue.go +++ b/routers/api/v1/repo/issue.go @@ -864,10 +864,11 @@ func EditIssue(ctx *context.APIContext) { } if form.State != nil { if issue.IsPull { - if pr, err := issue.GetPullRequest(); err != nil { + if err := issue.LoadPullRequest(ctx); err != nil { ctx.Error(http.StatusInternalServerError, "GetPullRequest", err) return - } else if pr.HasMerged { + } + if issue.PullRequest.HasMerged { ctx.Error(http.StatusPreconditionFailed, "MergedPRState", "cannot change state of this pull request, it was already merged") return } diff --git a/routers/api/v1/repo/issue_pin.go b/routers/api/v1/repo/issue_pin.go index d9b1bcc894..9bd609fcac 100644 --- a/routers/api/v1/repo/issue_pin.go +++ b/routers/api/v1/repo/issue_pin.go @@ -240,18 +240,12 @@ func ListPinnedPullRequests(ctx *context.APIContext) { } apiPrs := make([]*api.PullRequest, len(issues)) + if err := issues.LoadPullRequests(ctx); err != nil { + ctx.Error(http.StatusInternalServerError, "LoadPullRequests", err) + return + } for i, currentIssue := range issues { - pr, err := currentIssue.GetPullRequest() - if err != nil { - ctx.Error(http.StatusInternalServerError, "GetPullRequest", err) - return - } - - if err = pr.LoadIssue(ctx); err != nil { - ctx.Error(http.StatusInternalServerError, "LoadIssue", err) - return - } - + pr := currentIssue.PullRequest if err = pr.LoadAttributes(ctx); err != nil { ctx.Error(http.StatusInternalServerError, "LoadAttributes", err) return diff --git a/routers/web/user/notification.go b/routers/web/user/notification.go index 579287ffac..6fa7a0bb09 100644 --- a/routers/web/user/notification.go +++ b/routers/web/user/notification.go @@ -128,6 +128,12 @@ func getNotifications(ctx *context.Context) { ctx.ServerError("LoadIssues", err) return } + + if err = notifications.LoadIssuePullRequests(ctx); err != nil { + ctx.ServerError("LoadIssuePullRequests", err) + return + } + notifications = notifications.Without(failures) failCount += len(failures) diff --git a/services/convert/notification.go b/services/convert/notification.go index 7f724cf156..4d14777f40 100644 --- a/services/convert/notification.go +++ b/services/convert/notification.go @@ -61,8 +61,9 @@ func ToNotificationThread(ctx context.Context, n *activities_model.Notification) result.Subject.LatestCommentHTMLURL = comment.HTMLURL(ctx) } - pr, _ := n.Issue.GetPullRequest() - if pr != nil && pr.HasMerged { + if err := n.Issue.LoadPullRequest(ctx); err == nil && + n.Issue.PullRequest != nil && + n.Issue.PullRequest.HasMerged { result.Subject.State = "merged" } } diff --git a/services/pull/review.go b/services/pull/review.go index 64eea22eec..cdfb77b67b 100644 --- a/services/pull/review.go +++ b/services/pull/review.go @@ -264,11 +264,11 @@ func createCodeComment(ctx context.Context, doer *user_model.User, repo *repo_mo // SubmitReview creates a review out of the existing pending review or creates a new one if no pending review exist func SubmitReview(ctx context.Context, doer *user_model.User, gitRepo *git.Repository, issue *issues_model.Issue, reviewType issues_model.ReviewType, content, commitID string, attachmentUUIDs []string) (*issues_model.Review, *issues_model.Comment, error) { - pr, err := issue.GetPullRequest() - if err != nil { + if err := issue.LoadPullRequest(ctx); err != nil { return nil, nil, err } + pr := issue.PullRequest var stale bool if reviewType != issues_model.ReviewTypeApprove && reviewType != issues_model.ReviewTypeReject { stale = false diff --git a/templates/shared/issueicon.tmpl b/templates/shared/issueicon.tmpl index 94132825c4..a62714e988 100644 --- a/templates/shared/issueicon.tmpl +++ b/templates/shared/issueicon.tmpl @@ -1,15 +1,15 @@ {{if .IsPull}} - {{if and .PullRequest .PullRequest.HasMerged}} - {{svg "octicon-git-merge" 16 "text purple"}} - {{else if and .GetPullRequest .GetPullRequest.HasMerged}} - {{svg "octicon-git-merge" 16 "text purple"}} + {{if not .PullRequest}} + No PullRequest {{else}} {{if .IsClosed}} - {{svg "octicon-git-pull-request" 16 "text red"}} + {{if .PullRequest.HasMerged}} + {{svg "octicon-git-merge" 16 "text purple"}} + {{else}} + {{svg "octicon-git-pull-request" 16 "text red"}} + {{end}} {{else}} - {{if and .PullRequest (.PullRequest.IsWorkInProgress ctx)}} - {{svg "octicon-git-pull-request-draft" 16 "text grey"}} - {{else if and .GetPullRequest (.GetPullRequest.IsWorkInProgress ctx)}} + {{if .PullRequest.IsWorkInProgress ctx}} {{svg "octicon-git-pull-request-draft" 16 "text grey"}} {{else}} {{svg "octicon-git-pull-request" 16 "text green"}} diff --git a/tests/integration/pull_merge_test.go b/tests/integration/pull_merge_test.go index 2853b9f899..2eca9ce5eb 100644 --- a/tests/integration/pull_merge_test.go +++ b/tests/integration/pull_merge_test.go @@ -424,8 +424,8 @@ func TestConflictChecking(t *testing.T) { assert.NoError(t, err) issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{Title: "PR with conflict!"}) - conflictingPR, err := issues_model.GetPullRequestByIssueID(db.DefaultContext, issue.ID) - assert.NoError(t, err) + assert.NoError(t, issue.LoadPullRequest(db.DefaultContext)) + conflictingPR := issue.PullRequest // Ensure conflictedFiles is populated. assert.Len(t, conflictingPR.ConflictedFiles, 1) diff --git a/tests/integration/pull_update_test.go b/tests/integration/pull_update_test.go index e4b2ae65bd..5ff40cc6b7 100644 --- a/tests/integration/pull_update_test.go +++ b/tests/integration/pull_update_test.go @@ -175,8 +175,7 @@ func createOutdatedPR(t *testing.T, actor, forkOrg *user_model.User) *issues_mod assert.NoError(t, err) issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{Title: "Test Pull -to-update-"}) - pr, err := issues_model.GetPullRequestByIssueID(db.DefaultContext, issue.ID) - assert.NoError(t, err) + assert.NoError(t, issue.LoadPullRequest(db.DefaultContext)) - return pr + return issue.PullRequest }