1
1
mirror of https://github.com/go-gitea/gitea synced 2025-12-07 13:28:25 +00:00

Merge branch 'main' into api-repo-actions

This commit is contained in:
Chester
2024-04-03 11:11:41 -04:00
committed by GitHub
2487 changed files with 63536 additions and 43295 deletions
+102
View File
@@ -0,0 +1,102 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"fmt"
"net/http"
"strings"
"time"
"code.gitea.io/gitea/modules/json"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"github.com/golang-jwt/jwt/v5"
)
type actionsClaims struct {
jwt.RegisteredClaims
Scp string `json:"scp"`
TaskID int64
RunID int64
JobID int64
Ac string `json:"ac"`
}
type actionsCacheScope struct {
Scope string
Permission actionsCachePermission
}
type actionsCachePermission int
const (
actionsCachePermissionRead = 1 << iota
actionsCachePermissionWrite
)
func CreateAuthorizationToken(taskID, runID, jobID int64) (string, error) {
now := time.Now()
ac, err := json.Marshal(&[]actionsCacheScope{
{
Scope: "",
Permission: actionsCachePermissionWrite,
},
})
if err != nil {
return "", err
}
claims := actionsClaims{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(24 * time.Hour)),
NotBefore: jwt.NewNumericDate(now),
},
Scp: fmt.Sprintf("Actions.Results:%d:%d", runID, jobID),
Ac: string(ac),
TaskID: taskID,
RunID: runID,
JobID: jobID,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(setting.GetGeneralTokenSigningSecret())
if err != nil {
return "", err
}
return tokenString, nil
}
func ParseAuthorizationToken(req *http.Request) (int64, error) {
h := req.Header.Get("Authorization")
if h == "" {
return 0, nil
}
parts := strings.SplitN(h, " ", 2)
if len(parts) != 2 {
log.Error("split token failed: %s", h)
return 0, fmt.Errorf("split token failed")
}
token, err := jwt.ParseWithClaims(parts[1], &actionsClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return setting.GetGeneralTokenSigningSecret(), nil
})
if err != nil {
return 0, err
}
c, ok := token.Claims.(*actionsClaims)
if !token.Valid || !ok {
return 0, fmt.Errorf("invalid token claim")
}
return c.TaskID, nil
}
+64
View File
@@ -0,0 +1,64 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"net/http"
"testing"
"code.gitea.io/gitea/modules/json"
"code.gitea.io/gitea/modules/setting"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
)
func TestCreateAuthorizationToken(t *testing.T) {
var taskID int64 = 23
token, err := CreateAuthorizationToken(taskID, 1, 2)
assert.Nil(t, err)
assert.NotEqual(t, "", token)
claims := jwt.MapClaims{}
_, err = jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (interface{}, error) {
return setting.GetGeneralTokenSigningSecret(), nil
})
assert.Nil(t, err)
scp, ok := claims["scp"]
assert.True(t, ok, "Has scp claim in jwt token")
assert.Contains(t, scp, "Actions.Results:1:2")
taskIDClaim, ok := claims["TaskID"]
assert.True(t, ok, "Has TaskID claim in jwt token")
assert.Equal(t, float64(taskID), taskIDClaim, "Supplied taskid must match stored one")
acClaim, ok := claims["ac"]
assert.True(t, ok, "Has ac claim in jwt token")
ac, ok := acClaim.(string)
assert.True(t, ok, "ac claim is a string for buildx gha cache")
scopes := []actionsCacheScope{}
err = json.Unmarshal([]byte(ac), &scopes)
assert.NoError(t, err, "ac claim is a json list for buildx gha cache")
assert.GreaterOrEqual(t, len(scopes), 1, "Expected at least one action cache scope for buildx gha cache")
}
func TestParseAuthorizationToken(t *testing.T) {
var taskID int64 = 23
token, err := CreateAuthorizationToken(taskID, 1, 2)
assert.Nil(t, err)
assert.NotEqual(t, "", token)
headers := http.Header{}
headers.Set("Authorization", "Bearer "+token)
rTaskID, err := ParseAuthorizationToken(&http.Request{
Header: headers,
})
assert.Nil(t, err)
assert.Equal(t, taskID, rTaskID)
}
func TestParseAuthorizationTokenNoAuthHeader(t *testing.T) {
headers := http.Header{}
rTaskID, err := ParseAuthorizationToken(&http.Request{
Header: headers,
})
assert.Nil(t, err)
assert.Equal(t, int64(0), rTaskID)
}
+41 -5
View File
@@ -20,23 +20,59 @@ func Cleanup(taskCtx context.Context, olderThan time.Duration) error {
return CleanupArtifacts(taskCtx)
}
// CleanupArtifacts removes expired artifacts and set records expired status
// CleanupArtifacts removes expired add need-deleted artifacts and set records expired status
func CleanupArtifacts(taskCtx context.Context) error {
if err := cleanExpiredArtifacts(taskCtx); err != nil {
return err
}
return cleanNeedDeleteArtifacts(taskCtx)
}
func cleanExpiredArtifacts(taskCtx context.Context) error {
artifacts, err := actions.ListNeedExpiredArtifacts(taskCtx)
if err != nil {
return err
}
log.Info("Found %d expired artifacts", len(artifacts))
for _, artifact := range artifacts {
if err := storage.ActionsArtifacts.Delete(artifact.StoragePath); err != nil {
log.Error("Cannot delete artifact %d: %v", artifact.ID, err)
continue
}
if err := actions.SetArtifactExpired(taskCtx, artifact.ID); err != nil {
log.Error("Cannot set artifact %d expired: %v", artifact.ID, err)
continue
}
if err := storage.ActionsArtifacts.Delete(artifact.StoragePath); err != nil {
log.Error("Cannot delete artifact %d: %v", artifact.ID, err)
continue
}
log.Info("Artifact %d set expired", artifact.ID)
}
return nil
}
// deleteArtifactBatchSize is the batch size of deleting artifacts
const deleteArtifactBatchSize = 100
func cleanNeedDeleteArtifacts(taskCtx context.Context) error {
for {
artifacts, err := actions.ListPendingDeleteArtifacts(taskCtx, deleteArtifactBatchSize)
if err != nil {
return err
}
log.Info("Found %d artifacts pending deletion", len(artifacts))
for _, artifact := range artifacts {
if err := actions.SetArtifactDeleted(taskCtx, artifact.ID); err != nil {
log.Error("Cannot set artifact %d deleted: %v", artifact.ID, err)
continue
}
if err := storage.ActionsArtifacts.Delete(artifact.StoragePath); err != nil {
log.Error("Cannot delete artifact %d: %v", artifact.ID, err)
continue
}
log.Info("Artifact %d set deleted", artifact.ID)
}
if len(artifacts) < deleteArtifactBatchSize {
log.Debug("No more artifacts pending deletion")
break
}
}
return nil
}
+2 -2
View File
@@ -33,7 +33,7 @@ func StopEndlessTasks(ctx context.Context) error {
}
func stopTasks(ctx context.Context, opts actions_model.FindTaskOptions) error {
tasks, err := actions_model.FindTasks(ctx, opts)
tasks, err := db.Find[actions_model.ActionTask](ctx, opts)
if err != nil {
return fmt.Errorf("find tasks: %w", err)
}
@@ -74,7 +74,7 @@ func stopTasks(ctx context.Context, opts actions_model.FindTaskOptions) error {
// CancelAbandonedJobs cancels the jobs which have waiting status, but haven't been picked by a runner for a long time
func CancelAbandonedJobs(ctx context.Context) error {
jobs, _, err := actions_model.FindRunJobs(ctx, actions_model.FindRunJobOptions{
jobs, err := db.Find[actions_model.ActionRunJob](ctx, actions_model.FindRunJobOptions{
Statuses: []actions_model.Status{actions_model.StatusWaiting, actions_model.StatusBlocked},
UpdatedBefore: timeutil.TimeStamp(time.Now().Add(-setting.Actions.AbandonedJobTimeout).Unix()),
})
+10 -2
View File
@@ -12,6 +12,7 @@ import (
"code.gitea.io/gitea/models/db"
git_model "code.gitea.io/gitea/models/git"
user_model "code.gitea.io/gitea/models/user"
git "code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/log"
api "code.gitea.io/gitea/modules/structs"
webhook_module "code.gitea.io/gitea/modules/webhook"
@@ -63,6 +64,9 @@ func createCommitStatus(ctx context.Context, job *actions_model.ActionRunJob) er
return fmt.Errorf("head of pull request is missing in event payload")
}
sha = payload.PullRequest.Head.Sha
case webhook_module.HookEventRelease:
event = string(run.Event)
sha = run.CommitSHA
default:
return nil
}
@@ -75,7 +79,7 @@ func createCommitStatus(ctx context.Context, job *actions_model.ActionRunJob) er
}
ctxname := fmt.Sprintf("%s / %s (%s)", runName, job.Name, event)
state := toCommitStatus(job.Status)
if statuses, _, err := git_model.GetLatestCommitStatus(ctx, repo.ID, sha, db.ListOptions{ListAll: true}); err == nil {
if statuses, _, err := git_model.GetLatestCommitStatus(ctx, repo.ID, sha, db.ListOptionsAll); err == nil {
for _, v := range statuses {
if v.Context == ctxname {
if v.State == state {
@@ -114,9 +118,13 @@ func createCommitStatus(ctx context.Context, job *actions_model.ActionRunJob) er
}
creator := user_model.NewActionsUser()
commitID, err := git.NewIDFromString(sha)
if err != nil {
return fmt.Errorf("HashTypeInterfaceFromHashString: %w", err)
}
if err := git_model.NewCommitStatus(ctx, git_model.NewCommitStatusOptions{
Repo: repo,
SHA: sha,
SHA: commitID,
Creator: creator,
CommitStatus: &git_model.CommitStatus{
SHA: sha,
+21 -2
View File
@@ -7,12 +7,14 @@ import (
"context"
"errors"
"fmt"
"strings"
actions_model "code.gitea.io/gitea/models/actions"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/graceful"
"code.gitea.io/gitea/modules/queue"
"github.com/nektos/act/pkg/jobparser"
"xorm.io/builder"
)
@@ -44,7 +46,7 @@ func jobEmitterQueueHandler(items ...*jobUpdate) []*jobUpdate {
}
func checkJobsOfRun(ctx context.Context, runID int64) error {
jobs, _, err := actions_model.FindRunJobs(ctx, actions_model.FindRunJobOptions{RunID: runID})
jobs, err := db.Find[actions_model.ActionRunJob](ctx, actions_model.FindRunJobOptions{RunID: runID})
if err != nil {
return err
}
@@ -76,12 +78,15 @@ func checkJobsOfRun(ctx context.Context, runID int64) error {
type jobStatusResolver struct {
statuses map[int64]actions_model.Status
needs map[int64][]int64
jobMap map[int64]*actions_model.ActionRunJob
}
func newJobStatusResolver(jobs actions_model.ActionJobList) *jobStatusResolver {
idToJobs := make(map[string][]*actions_model.ActionRunJob, len(jobs))
jobMap := make(map[int64]*actions_model.ActionRunJob)
for _, job := range jobs {
idToJobs[job.JobID] = append(idToJobs[job.JobID], job)
jobMap[job.ID] = job
}
statuses := make(map[int64]actions_model.Status, len(jobs))
@@ -97,6 +102,7 @@ func newJobStatusResolver(jobs actions_model.ActionJobList) *jobStatusResolver {
return &jobStatusResolver{
statuses: statuses,
needs: needs,
jobMap: jobMap,
}
}
@@ -135,7 +141,20 @@ func (r *jobStatusResolver) resolve() map[int64]actions_model.Status {
if allSucceed {
ret[id] = actions_model.StatusWaiting
} else {
ret[id] = actions_model.StatusSkipped
// If a job's "if" condition is "always()", the job should always run even if some of its dependencies did not succeed.
// See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idneeds
always := false
if wfJobs, _ := jobparser.Parse(r.jobMap[id].WorkflowPayload); len(wfJobs) == 1 {
_, wfJob := wfJobs[0].Job()
expr := strings.TrimSpace(strings.TrimSuffix(strings.TrimPrefix(wfJob.If.Value, "${{"), "}}"))
always = expr == "always()"
}
if always {
ret[id] = actions_model.StatusWaiting
} else {
ret[id] = actions_model.StatusSkipped
}
}
}
}
+56
View File
@@ -70,6 +70,62 @@ func Test_jobStatusResolver_Resolve(t *testing.T) {
},
want: map[int64]actions_model.Status{},
},
{
name: "with ${{ always() }} condition",
jobs: actions_model.ActionJobList{
{ID: 1, JobID: "job1", Status: actions_model.StatusFailure, Needs: []string{}},
{ID: 2, JobID: "job2", Status: actions_model.StatusBlocked, Needs: []string{"job1"}, WorkflowPayload: []byte(
`
name: test
on: push
jobs:
job2:
runs-on: ubuntu-latest
needs: job1
if: ${{ always() }}
steps:
- run: echo "always run"
`)},
},
want: map[int64]actions_model.Status{2: actions_model.StatusWaiting},
},
{
name: "with always() condition",
jobs: actions_model.ActionJobList{
{ID: 1, JobID: "job1", Status: actions_model.StatusFailure, Needs: []string{}},
{ID: 2, JobID: "job2", Status: actions_model.StatusBlocked, Needs: []string{"job1"}, WorkflowPayload: []byte(
`
name: test
on: push
jobs:
job2:
runs-on: ubuntu-latest
needs: job1
if: always()
steps:
- run: echo "always run"
`)},
},
want: map[int64]actions_model.Status{2: actions_model.StatusWaiting},
},
{
name: "without always() condition",
jobs: actions_model.ActionJobList{
{ID: 1, JobID: "job1", Status: actions_model.StatusFailure, Needs: []string{}},
{ID: 2, JobID: "job2", Status: actions_model.StatusBlocked, Needs: []string{"job1"}, WorkflowPayload: []byte(
`
name: test
on: push
jobs:
job2:
runs-on: ubuntu-latest
needs: job1
steps:
- run: echo "not always run"
`)},
},
want: map[int64]actions_model.Status{2: actions_model.StatusSkipped},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+224 -35
View File
@@ -55,6 +55,47 @@ func (n *actionsNotifier) NewIssue(ctx context.Context, issue *issues_model.Issu
}).Notify(withMethod(ctx, "NewIssue"))
}
// IssueChangeContent notifies change content of issue
func (n *actionsNotifier) IssueChangeContent(ctx context.Context, doer *user_model.User, issue *issues_model.Issue, oldContent string) {
ctx = withMethod(ctx, "IssueChangeContent")
var err error
if err = issue.LoadRepo(ctx); err != nil {
log.Error("LoadRepo: %v", err)
return
}
permission, _ := access_model.GetUserRepoPermission(ctx, issue.Repo, issue.Poster)
if issue.IsPull {
if err = issue.LoadPullRequest(ctx); err != nil {
log.Error("loadPullRequest: %v", err)
return
}
newNotifyInputFromIssue(issue, webhook_module.HookEventPullRequest).
WithDoer(doer).
WithPayload(&api.PullRequestPayload{
Action: api.HookIssueEdited,
Index: issue.Index,
PullRequest: convert.ToAPIPullRequest(ctx, issue.PullRequest, nil),
Repository: convert.ToRepo(ctx, issue.Repo, access_model.Permission{AccessMode: perm_model.AccessModeNone}),
Sender: convert.ToUser(ctx, doer, nil),
}).
WithPullRequest(issue.PullRequest).
Notify(ctx)
return
}
newNotifyInputFromIssue(issue, webhook_module.HookEventIssues).
WithDoer(doer).
WithPayload(&api.IssuePayload{
Action: api.HookIssueEdited,
Index: issue.Index,
Issue: convert.ToAPIIssue(ctx, issue),
Repository: convert.ToRepo(ctx, issue.Repo, permission),
Sender: convert.ToUser(ctx, doer, nil),
}).
Notify(ctx)
}
// IssueChangeStatus notifies close or reopen issue to notifiers
func (n *actionsNotifier) IssueChangeStatus(ctx context.Context, doer *user_model.User, commitID string, issue *issues_model.Issue, _ *issues_model.Comment, isClosed bool) {
ctx = withMethod(ctx, "IssueChangeStatus")
@@ -101,11 +142,58 @@ func (n *actionsNotifier) IssueChangeStatus(ctx context.Context, doer *user_mode
Notify(ctx)
}
// IssueChangeAssignee notifies assigned or unassigned to notifiers
func (n *actionsNotifier) IssueChangeAssignee(ctx context.Context, doer *user_model.User, issue *issues_model.Issue, assignee *user_model.User, removed bool, comment *issues_model.Comment) {
ctx = withMethod(ctx, "IssueChangeAssignee")
var action api.HookIssueAction
if removed {
action = api.HookIssueUnassigned
} else {
action = api.HookIssueAssigned
}
hookEvent := webhook_module.HookEventIssueAssign
if issue.IsPull {
hookEvent = webhook_module.HookEventPullRequestAssign
}
notifyIssueChange(ctx, doer, issue, hookEvent, action)
}
// IssueChangeMilestone notifies assignee to notifiers
func (n *actionsNotifier) IssueChangeMilestone(ctx context.Context, doer *user_model.User, issue *issues_model.Issue, oldMilestoneID int64) {
ctx = withMethod(ctx, "IssueChangeMilestone")
var action api.HookIssueAction
if issue.MilestoneID > 0 {
action = api.HookIssueMilestoned
} else {
action = api.HookIssueDemilestoned
}
hookEvent := webhook_module.HookEventIssueMilestone
if issue.IsPull {
hookEvent = webhook_module.HookEventPullRequestMilestone
}
notifyIssueChange(ctx, doer, issue, hookEvent, action)
}
func (n *actionsNotifier) IssueChangeLabels(ctx context.Context, doer *user_model.User, issue *issues_model.Issue,
_, _ []*issues_model.Label,
) {
ctx = withMethod(ctx, "IssueChangeLabels")
hookEvent := webhook_module.HookEventIssueLabel
if issue.IsPull {
hookEvent = webhook_module.HookEventPullRequestLabel
}
notifyIssueChange(ctx, doer, issue, hookEvent, api.HookIssueLabelUpdated)
}
func notifyIssueChange(ctx context.Context, doer *user_model.User, issue *issues_model.Issue, event webhook_module.HookEventType, action api.HookIssueAction) {
var err error
if err = issue.LoadRepo(ctx); err != nil {
log.Error("LoadRepo: %v", err)
@@ -117,20 +205,15 @@ func (n *actionsNotifier) IssueChangeLabels(ctx context.Context, doer *user_mode
return
}
permission, _ := access_model.GetUserRepoPermission(ctx, issue.Repo, issue.Poster)
if issue.IsPull {
if err = issue.LoadPullRequest(ctx); err != nil {
log.Error("loadPullRequest: %v", err)
return
}
if err = issue.PullRequest.LoadIssue(ctx); err != nil {
log.Error("LoadIssue: %v", err)
return
}
newNotifyInputFromIssue(issue, webhook_module.HookEventPullRequestLabel).
newNotifyInputFromIssue(issue, event).
WithDoer(doer).
WithPayload(&api.PullRequestPayload{
Action: api.HookIssueLabelUpdated,
Action: action,
Index: issue.Index,
PullRequest: convert.ToAPIPullRequest(ctx, issue.PullRequest, nil),
Repository: convert.ToRepo(ctx, issue.Repo, access_model.Permission{AccessMode: perm_model.AccessModeNone}),
@@ -140,10 +223,11 @@ func (n *actionsNotifier) IssueChangeLabels(ctx context.Context, doer *user_mode
Notify(ctx)
return
}
newNotifyInputFromIssue(issue, webhook_module.HookEventIssueLabel).
permission, _ := access_model.GetUserRepoPermission(ctx, issue.Repo, issue.Poster)
newNotifyInputFromIssue(issue, event).
WithDoer(doer).
WithPayload(&api.IssuePayload{
Action: api.HookIssueLabelUpdated,
Action: action,
Index: issue.Index,
Issue: convert.ToAPIIssue(ctx, issue),
Repository: convert.ToRepo(ctx, issue.Repo, permission),
@@ -158,37 +242,88 @@ func (n *actionsNotifier) CreateIssueComment(ctx context.Context, doer *user_mod
) {
ctx = withMethod(ctx, "CreateIssueComment")
permission, _ := access_model.GetUserRepoPermission(ctx, repo, doer)
if issue.IsPull {
if err := issue.LoadPullRequest(ctx); err != nil {
notifyIssueCommentChange(ctx, doer, comment, "", webhook_module.HookEventPullRequestComment, api.HookIssueCommentCreated)
return
}
notifyIssueCommentChange(ctx, doer, comment, "", webhook_module.HookEventIssueComment, api.HookIssueCommentCreated)
}
func (n *actionsNotifier) UpdateComment(ctx context.Context, doer *user_model.User, c *issues_model.Comment, oldContent string) {
ctx = withMethod(ctx, "UpdateComment")
if err := c.LoadIssue(ctx); err != nil {
log.Error("LoadIssue: %v", err)
return
}
if c.Issue.IsPull {
notifyIssueCommentChange(ctx, doer, c, oldContent, webhook_module.HookEventPullRequestComment, api.HookIssueCommentEdited)
return
}
notifyIssueCommentChange(ctx, doer, c, oldContent, webhook_module.HookEventIssueComment, api.HookIssueCommentEdited)
}
func (n *actionsNotifier) DeleteComment(ctx context.Context, doer *user_model.User, comment *issues_model.Comment) {
ctx = withMethod(ctx, "DeleteComment")
if err := comment.LoadIssue(ctx); err != nil {
log.Error("LoadIssue: %v", err)
return
}
if comment.Issue.IsPull {
notifyIssueCommentChange(ctx, doer, comment, "", webhook_module.HookEventPullRequestComment, api.HookIssueCommentDeleted)
return
}
notifyIssueCommentChange(ctx, doer, comment, "", webhook_module.HookEventIssueComment, api.HookIssueCommentDeleted)
}
func notifyIssueCommentChange(ctx context.Context, doer *user_model.User, comment *issues_model.Comment, oldContent string, event webhook_module.HookEventType, action api.HookIssueCommentAction) {
if err := comment.LoadIssue(ctx); err != nil {
log.Error("LoadIssue: %v", err)
return
}
if err := comment.Issue.LoadAttributes(ctx); err != nil {
log.Error("LoadAttributes: %v", err)
return
}
permission, _ := access_model.GetUserRepoPermission(ctx, comment.Issue.Repo, doer)
payload := &api.IssueCommentPayload{
Action: action,
Issue: convert.ToAPIIssue(ctx, comment.Issue),
Comment: convert.ToAPIComment(ctx, comment.Issue.Repo, comment),
Repository: convert.ToRepo(ctx, comment.Issue.Repo, permission),
Sender: convert.ToUser(ctx, doer, nil),
IsPull: comment.Issue.IsPull,
}
if action == api.HookIssueCommentEdited {
payload.Changes = &api.ChangesPayload{
Body: &api.ChangesFromPayload{
From: oldContent,
},
}
}
if comment.Issue.IsPull {
if err := comment.Issue.LoadPullRequest(ctx); err != nil {
log.Error("LoadPullRequest: %v", err)
return
}
newNotifyInputFromIssue(issue, webhook_module.HookEventPullRequestComment).
newNotifyInputFromIssue(comment.Issue, event).
WithDoer(doer).
WithPayload(&api.IssueCommentPayload{
Action: api.HookIssueCommentCreated,
Issue: convert.ToAPIIssue(ctx, issue),
Comment: convert.ToAPIComment(ctx, repo, comment),
Repository: convert.ToRepo(ctx, repo, permission),
Sender: convert.ToUser(ctx, doer, nil),
IsPull: true,
}).
WithPullRequest(issue.PullRequest).
WithPayload(payload).
WithPullRequest(comment.Issue.PullRequest).
Notify(ctx)
return
}
newNotifyInputFromIssue(issue, webhook_module.HookEventIssueComment).
newNotifyInputFromIssue(comment.Issue, event).
WithDoer(doer).
WithPayload(&api.IssueCommentPayload{
Action: api.HookIssueCommentCreated,
Issue: convert.ToAPIIssue(ctx, issue),
Comment: convert.ToAPIComment(ctx, repo, comment),
Repository: convert.ToRepo(ctx, repo, permission),
Sender: convert.ToUser(ctx, doer, nil),
IsPull: false,
}).
WithPayload(payload).
Notify(ctx)
}
@@ -305,6 +440,39 @@ func (n *actionsNotifier) PullRequestReview(ctx context.Context, pr *issues_mode
}).Notify(ctx)
}
func (n *actionsNotifier) PullRequestReviewRequest(ctx context.Context, doer *user_model.User, issue *issues_model.Issue, reviewer *user_model.User, isRequest bool, comment *issues_model.Comment) {
if !issue.IsPull {
log.Warn("PullRequestReviewRequest: issue is not a pull request: %v", issue.ID)
return
}
ctx = withMethod(ctx, "PullRequestReviewRequest")
permission, _ := access_model.GetUserRepoPermission(ctx, issue.Repo, doer)
if err := issue.LoadPullRequest(ctx); err != nil {
log.Error("LoadPullRequest failed: %v", err)
return
}
var action api.HookIssueAction
if isRequest {
action = api.HookIssueReviewRequested
} else {
action = api.HookIssueReviewRequestRemoved
}
newNotifyInputFromIssue(issue, webhook_module.HookEventPullRequestReviewRequest).
WithDoer(doer).
WithPayload(&api.PullRequestPayload{
Action: action,
Index: issue.Index,
PullRequest: convert.ToAPIPullRequest(ctx, issue.PullRequest, nil),
RequestedReviewer: convert.ToUser(ctx, reviewer, nil),
Repository: convert.ToRepo(ctx, issue.Repo, permission),
Sender: convert.ToUser(ctx, doer, nil),
}).
WithPullRequest(issue.PullRequest).
Notify(ctx)
}
func (*actionsNotifier) MergePullRequest(ctx context.Context, doer *user_model.User, pr *issues_model.PullRequest) {
ctx = withMethod(ctx, "MergePullRequest")
@@ -347,6 +515,12 @@ func (*actionsNotifier) MergePullRequest(ctx context.Context, doer *user_model.U
}
func (n *actionsNotifier) PushCommits(ctx context.Context, pusher *user_model.User, repo *repo_model.Repository, opts *repository.PushUpdateOptions, commits *repository.PushCommits) {
commitID, _ := git.NewIDFromString(opts.NewCommitID)
if commitID.IsZero() {
log.Trace("new commitID is empty")
return
}
ctx = withMethod(ctx, "PushCommits")
apiPusher := convert.ToUser(ctx, pusher, nil)
@@ -379,9 +553,9 @@ func (n *actionsNotifier) CreateRef(ctx context.Context, pusher *user_model.User
apiRepo := convert.ToRepo(ctx, repo, access_model.Permission{AccessMode: perm_model.AccessModeNone})
newNotifyInput(repo, pusher, webhook_module.HookEventCreate).
WithRef(refFullName.ShortName()). // FIXME: should we use a full ref name
WithRef(refFullName.String()).
WithPayload(&api.CreatePayload{
Ref: refFullName.ShortName(),
Ref: refFullName.String(),
Sha: refID,
RefType: refFullName.RefType(),
Repo: apiRepo,
@@ -397,9 +571,8 @@ func (n *actionsNotifier) DeleteRef(ctx context.Context, pusher *user_model.User
apiRepo := convert.ToRepo(ctx, repo, access_model.Permission{AccessMode: perm_model.AccessModeNone})
newNotifyInput(repo, pusher, webhook_module.HookEventDelete).
WithRef(refFullName.ShortName()). // FIXME: should we use a full ref name
WithPayload(&api.DeletePayload{
Ref: refFullName.ShortName(),
Ref: refFullName.String(),
RefType: refFullName.RefType(),
PusherType: api.PusherTypeUser,
Repo: apiRepo,
@@ -456,6 +629,10 @@ func (n *actionsNotifier) UpdateRelease(ctx context.Context, doer *user_model.Us
}
func (n *actionsNotifier) DeleteRelease(ctx context.Context, doer *user_model.User, rel *repo_model.Release) {
if rel.IsTag {
// has sent same action in `PushCommits`, so skip it.
return
}
ctx = withMethod(ctx, "DeleteRelease")
notifyRelease(ctx, doer, rel, api.HookReleaseDeleted)
}
@@ -565,3 +742,15 @@ func (n *actionsNotifier) DeleteWikiPage(ctx context.Context, doer *user_model.U
Page: page,
}).Notify(ctx)
}
// MigrateRepository is used to detect workflows after a repository has been migrated
func (n *actionsNotifier) MigrateRepository(ctx context.Context, doer, u *user_model.User, repo *repo_model.Repository) {
ctx = withMethod(ctx, "MigrateRepository")
newNotifyInput(repo, doer, webhook_module.HookEventRepository).WithPayload(&api.RepositoryPayload{
Action: api.HookRepoCreated,
Repository: convert.ToRepo(ctx, repo, access_model.Permission{AccessMode: perm_model.AccessModeOwner}),
Organization: convert.ToUser(ctx, u, nil),
Sender: convert.ToUser(ctx, doer, nil),
}).Notify(ctx)
}
+134 -47
View File
@@ -7,9 +7,11 @@ import (
"bytes"
"context"
"fmt"
"slices"
"strings"
actions_model "code.gitea.io/gitea/models/actions"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
packages_model "code.gitea.io/gitea/models/packages"
access_model "code.gitea.io/gitea/models/perm/access"
@@ -18,8 +20,10 @@ import (
user_model "code.gitea.io/gitea/models/user"
actions_module "code.gitea.io/gitea/modules/actions"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/gitrepo"
"code.gitea.io/gitea/modules/json"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
api "code.gitea.io/gitea/modules/structs"
webhook_module "code.gitea.io/gitea/modules/webhook"
"code.gitea.io/gitea/services/convert"
@@ -113,7 +117,13 @@ func notify(ctx context.Context, input *notifyInput) error {
log.Debug("ignore executing %v for event %v whose doer is %v", getMethod(ctx), input.Event, input.Doer.Name)
return nil
}
if input.Repo.IsEmpty || input.Repo.IsArchived {
return nil
}
if unit_model.TypeActions.UnitGlobalDisabled() {
if err := actions_model.CleanRepoScheduleTasks(ctx, input.Repo); err != nil {
log.Error("CleanRepoScheduleTasks: %v", err)
}
return nil
}
if err := input.Repo.LoadUnits(ctx); err != nil {
@@ -122,19 +132,22 @@ func notify(ctx context.Context, input *notifyInput) error {
return nil
}
gitRepo, err := git.OpenRepository(context.Background(), input.Repo.RepoPath())
gitRepo, err := gitrepo.OpenRepository(context.Background(), input.Repo)
if err != nil {
return fmt.Errorf("git.OpenRepository: %w", err)
}
defer gitRepo.Close()
ref := input.Ref
if input.Event == webhook_module.HookEventDelete {
// The event is deleting a reference, so it will fail to get the commit for a deleted reference.
// Set ref to empty string to fall back to the default branch.
ref = ""
if ref != input.Repo.DefaultBranch && actions_module.IsDefaultBranchWorkflow(input.Event) {
if ref != "" {
log.Warn("Event %q should only trigger workflows on the default branch, but its ref is %q. Will fall back to the default branch",
input.Event, ref)
}
ref = input.Repo.DefaultBranch
}
if ref == "" {
log.Warn("Ref of event %q is empty, will fall back to the default branch", input.Event)
ref = input.Repo.DefaultBranch
}
@@ -144,25 +157,38 @@ func notify(ctx context.Context, input *notifyInput) error {
return fmt.Errorf("gitRepo.GetCommit: %w", err)
}
if skipWorkflows(input, commit) {
return nil
}
var detectedWorkflows []*actions_module.DetectedWorkflow
actionsConfig := input.Repo.MustGetUnit(ctx, unit_model.TypeActions).ActionsConfig()
workflows, schedules, err := actions_module.DetectWorkflows(gitRepo, commit, input.Event, input.Payload)
shouldDetectSchedules := input.Event == webhook_module.HookEventPush && git.RefName(input.Ref).BranchName() == input.Repo.DefaultBranch
workflows, schedules, err := actions_module.DetectWorkflows(gitRepo, commit,
input.Event,
input.Payload,
shouldDetectSchedules,
)
if err != nil {
return fmt.Errorf("DetectWorkflows: %w", err)
}
if len(workflows) == 0 {
log.Trace("repo %s with commit %s couldn't find workflows", input.Repo.RepoPath(), commit.ID)
} else {
for _, wf := range workflows {
if actionsConfig.IsWorkflowDisabled(wf.EntryName) {
log.Trace("repo %s has disable workflows %s", input.Repo.RepoPath(), wf.EntryName)
continue
}
log.Trace("repo %s with commit %s event %s find %d workflows and %d schedules",
input.Repo.RepoPath(),
commit.ID,
input.Event,
len(workflows),
len(schedules),
)
if wf.TriggerEvent != actions_module.GithubEventPullRequestTarget {
detectedWorkflows = append(detectedWorkflows, wf)
}
for _, wf := range workflows {
if actionsConfig.IsWorkflowDisabled(wf.EntryName) {
log.Trace("repo %s has disable workflows %s", input.Repo.RepoPath(), wf.EntryName)
continue
}
if wf.TriggerEvent.Name != actions_module.GithubEventPullRequestTarget {
detectedWorkflows = append(detectedWorkflows, wf)
}
}
@@ -173,7 +199,7 @@ func notify(ctx context.Context, input *notifyInput) error {
if err != nil {
return fmt.Errorf("gitRepo.GetCommit: %w", err)
}
baseWorkflows, _, err := actions_module.DetectWorkflows(gitRepo, baseCommit, input.Event, input.Payload)
baseWorkflows, _, err := actions_module.DetectWorkflows(gitRepo, baseCommit, input.Event, input.Payload, false)
if err != nil {
return fmt.Errorf("DetectWorkflows: %w", err)
}
@@ -181,20 +207,45 @@ func notify(ctx context.Context, input *notifyInput) error {
log.Trace("repo %s with commit %s couldn't find pull_request_target workflows", input.Repo.RepoPath(), baseCommit.ID)
} else {
for _, wf := range baseWorkflows {
if wf.TriggerEvent == actions_module.GithubEventPullRequestTarget {
if wf.TriggerEvent.Name == actions_module.GithubEventPullRequestTarget {
detectedWorkflows = append(detectedWorkflows, wf)
}
}
}
}
if err := handleSchedules(ctx, schedules, commit, input); err != nil {
return err
if shouldDetectSchedules {
if err := handleSchedules(ctx, schedules, commit, input, ref); err != nil {
return err
}
}
return handleWorkflows(ctx, detectedWorkflows, commit, input, ref)
}
func skipWorkflows(input *notifyInput, commit *git.Commit) bool {
// skip workflow runs with a configured skip-ci string in commit message or pr title if the event is push or pull_request(_sync)
// https://docs.github.com/en/actions/managing-workflow-runs/skipping-workflow-runs
skipWorkflowEvents := []webhook_module.HookEventType{
webhook_module.HookEventPush,
webhook_module.HookEventPullRequest,
webhook_module.HookEventPullRequestSync,
}
if slices.Contains(skipWorkflowEvents, input.Event) {
for _, s := range setting.Actions.SkipWorkflowStrings {
if input.PullRequest != nil && strings.Contains(input.PullRequest.Issue.Title, s) {
log.Debug("repo %s: skipped run for pr %v because of %s string", input.Repo.RepoPath(), input.PullRequest.Issue.ID, s)
return true
}
if strings.Contains(commit.CommitMessage, s) {
log.Debug("repo %s with commit %s: skipped run because of %s string", input.Repo.RepoPath(), commit.ID, s)
return true
}
}
}
return false
}
func handleWorkflows(
ctx context.Context,
detectedWorkflows []*actions_module.DetectedWorkflow,
@@ -239,7 +290,7 @@ func handleWorkflows(
IsForkPullRequest: isForkPullRequest,
Event: input.Event,
EventPayload: string(p),
TriggerEvent: dwf.TriggerEvent,
TriggerEvent: dwf.TriggerEvent.Name,
Status: actions_model.StatusWaiting,
}
if need, err := ifNeedApproval(ctx, run, input.Repo, input.Doer); err != nil {
@@ -249,22 +300,34 @@ func handleWorkflows(
run.NeedApproval = need
}
jobs, err := jobparser.Parse(dwf.Content)
if err := run.LoadAttributes(ctx); err != nil {
log.Error("LoadAttributes: %v", err)
continue
}
vars, err := actions_model.GetVariablesOfRun(ctx, run)
if err != nil {
log.Error("GetVariablesOfRun: %v", err)
continue
}
jobs, err := jobparser.Parse(dwf.Content, jobparser.WithVars(vars))
if err != nil {
log.Error("jobparser.Parse: %v", err)
continue
}
// cancel running jobs if the event is push
if run.Event == webhook_module.HookEventPush {
// cancel running jobs of the same workflow
if err := actions_model.CancelRunningJobs(
// cancel running jobs if the event is push or pull_request_sync
if run.Event == webhook_module.HookEventPush ||
run.Event == webhook_module.HookEventPullRequestSync {
if err := actions_model.CancelPreviousJobs(
ctx,
run.RepoID,
run.Ref,
run.WorkflowID,
run.Event,
); err != nil {
log.Error("CancelRunningJobs: %v", err)
log.Error("CancelPreviousJobs: %v", err)
}
}
@@ -273,7 +336,7 @@ func handleWorkflows(
continue
}
alljobs, _, err := actions_model.FindRunJobs(ctx, actions_model.FindRunJobOptions{RunID: run.ID})
alljobs, err := db.Find[actions_model.ActionRunJob](ctx, actions_model.FindRunJobOptions{RunID: run.ID})
if err != nil {
log.Error("FindRunJobs: %v", err)
continue
@@ -352,7 +415,7 @@ func ifNeedApproval(ctx context.Context, run *actions_model.ActionRun, repo *rep
}
// don't need approval if the user has been approved before
if count, err := actions_model.CountRuns(ctx, actions_model.FindRunOptions{
if count, err := db.Count[actions_model.ActionRun](ctx, actions_model.FindRunOptions{
RepoID: repo.ID,
TriggerUserID: user.ID,
Approved: true,
@@ -373,6 +436,7 @@ func handleSchedules(
detectedWorkflows []*actions_module.DetectedWorkflow,
commit *git.Commit,
input *notifyInput,
ref string,
) error {
branch, err := commit.GetBranchName()
if err != nil {
@@ -383,12 +447,12 @@ func handleSchedules(
return nil
}
if count, err := actions_model.CountSchedules(ctx, actions_model.FindScheduleOptions{RepoID: input.Repo.ID}); err != nil {
if count, err := db.Count[actions_model.ActionSchedule](ctx, actions_model.FindScheduleOptions{RepoID: input.Repo.ID}); err != nil {
log.Error("CountSchedules: %v", err)
return err
} else if count > 0 {
if err := actions_model.DeleteScheduleTaskByRepo(ctx, input.Repo.ID); err != nil {
log.Error("DeleteCronTaskByRepo: %v", err)
if err := actions_model.CleanRepoScheduleTasks(ctx, input.Repo); err != nil {
log.Error("CleanRepoScheduleTasks: %v", err)
}
}
@@ -422,28 +486,51 @@ func handleSchedules(
OwnerID: input.Repo.OwnerID,
WorkflowID: dwf.EntryName,
TriggerUserID: input.Doer.ID,
Ref: input.Ref,
Ref: ref,
CommitSHA: commit.ID.String(),
Event: input.Event,
EventPayload: string(p),
Specs: schedules,
Content: dwf.Content,
}
// cancel running jobs if the event is push
if run.Event == webhook_module.HookEventPush {
// cancel running jobs of the same workflow
if err := actions_model.CancelRunningJobs(
ctx,
run.RepoID,
run.Ref,
run.WorkflowID,
); err != nil {
log.Error("CancelRunningJobs: %v", err)
}
}
crons = append(crons, run)
}
return actions_model.CreateScheduleTask(ctx, crons)
}
// DetectAndHandleSchedules detects the schedule workflows on the default branch and create schedule tasks
func DetectAndHandleSchedules(ctx context.Context, repo *repo_model.Repository) error {
if repo.IsEmpty || repo.IsArchived {
return nil
}
gitRepo, err := gitrepo.OpenRepository(context.Background(), repo)
if err != nil {
return fmt.Errorf("git.OpenRepository: %w", err)
}
defer gitRepo.Close()
// Only detect schedule workflows on the default branch
commit, err := gitRepo.GetCommit(repo.DefaultBranch)
if err != nil {
return fmt.Errorf("gitRepo.GetCommit: %w", err)
}
scheduleWorkflows, err := actions_module.DetectScheduledWorkflows(gitRepo, commit)
if err != nil {
return fmt.Errorf("detect schedule workflows: %w", err)
}
if len(scheduleWorkflows) == 0 {
return nil
}
// We need a notifyInput to call handleSchedules
// Here we use the commit author as the Doer of the notifyInput
commitUser, err := user_model.GetUserByEmail(ctx, commit.Author.Email)
if err != nil {
return fmt.Errorf("get user by email: %w", err)
}
notifyInput := newNotifyInput(repo, commitUser, webhook_module.HookEventSchedule)
return handleSchedules(ctx, scheduleWorkflows, commit, notifyInput, repo.DefaultBranch)
}
+38
View File
@@ -0,0 +1,38 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
actions_model "code.gitea.io/gitea/models/actions"
"code.gitea.io/gitea/modules/container"
)
// GetAllRerunJobs get all jobs that need to be rerun when job should be rerun
func GetAllRerunJobs(job *actions_model.ActionRunJob, allJobs []*actions_model.ActionRunJob) []*actions_model.ActionRunJob {
rerunJobs := []*actions_model.ActionRunJob{job}
rerunJobsIDSet := make(container.Set[string])
rerunJobsIDSet.Add(job.JobID)
for {
found := false
for _, j := range allJobs {
if rerunJobsIDSet.Contains(j.JobID) {
continue
}
for _, need := range j.Needs {
if rerunJobsIDSet.Contains(need) {
found = true
rerunJobs = append(rerunJobs, j)
rerunJobsIDSet.Add(j.JobID)
break
}
}
}
if !found {
break
}
}
return rerunJobs
}
+48
View File
@@ -0,0 +1,48 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"testing"
actions_model "code.gitea.io/gitea/models/actions"
"github.com/stretchr/testify/assert"
)
func TestGetAllRerunJobs(t *testing.T) {
job1 := &actions_model.ActionRunJob{JobID: "job1"}
job2 := &actions_model.ActionRunJob{JobID: "job2", Needs: []string{"job1"}}
job3 := &actions_model.ActionRunJob{JobID: "job3", Needs: []string{"job2"}}
job4 := &actions_model.ActionRunJob{JobID: "job4", Needs: []string{"job2", "job3"}}
jobs := []*actions_model.ActionRunJob{job1, job2, job3, job4}
testCases := []struct {
job *actions_model.ActionRunJob
rerunJobs []*actions_model.ActionRunJob
}{
{
job1,
[]*actions_model.ActionRunJob{job1, job2, job3, job4},
},
{
job2,
[]*actions_model.ActionRunJob{job2, job3, job4},
},
{
job3,
[]*actions_model.ActionRunJob{job3, job4},
},
{
job4,
[]*actions_model.ActionRunJob{job4},
},
}
for _, tc := range testCases {
rerunJobs := GetAllRerunJobs(tc.job, jobs)
assert.ElementsMatch(t, tc.rerunJobs, rerunJobs)
}
}
+19 -4
View File
@@ -10,6 +10,7 @@ import (
actions_model "code.gitea.io/gitea/models/actions"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unit"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"
@@ -54,18 +55,31 @@ func startTasks(ctx context.Context) error {
// cancel running jobs if the event is push
if row.Schedule.Event == webhook_module.HookEventPush {
// cancel running jobs of the same workflow
if err := actions_model.CancelRunningJobs(
if err := actions_model.CancelPreviousJobs(
ctx,
row.RepoID,
row.Schedule.Ref,
row.Schedule.WorkflowID,
webhook_module.HookEventSchedule,
); err != nil {
log.Error("CancelRunningJobs: %v", err)
log.Error("CancelPreviousJobs: %v", err)
}
}
cfg := row.Repo.MustGetUnit(ctx, unit.TypeActions).ActionsConfig()
if cfg.IsWorkflowDisabled(row.Schedule.WorkflowID) {
if row.Repo.IsArchived {
// Skip if the repo is archived
continue
}
cfg, err := row.Repo.GetUnit(ctx, unit.TypeActions)
if err != nil {
if repo_model.IsErrUnitTypeNotExist(err) {
// Skip the actions unit of this repo is disabled.
continue
}
return fmt.Errorf("GetUnit: %w", err)
}
if cfg.ActionsConfig().IsWorkflowDisabled(row.Schedule.WorkflowID) {
continue
}
@@ -113,6 +127,7 @@ func CreateScheduleTask(ctx context.Context, cron *actions_model.ActionSchedule)
CommitSHA: cron.CommitSHA,
Event: cron.Event,
EventPayload: cron.EventPayload,
TriggerEvent: string(webhook_module.HookEventSchedule),
ScheduleID: cron.ID,
Status: actions_model.StatusWaiting,
}
+100
View File
@@ -0,0 +1,100 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"regexp"
"strings"
actions_model "code.gitea.io/gitea/models/actions"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/util"
secret_service "code.gitea.io/gitea/services/secrets"
)
func CreateVariable(ctx context.Context, ownerID, repoID int64, name, data string) (*actions_model.ActionVariable, error) {
if err := secret_service.ValidateName(name); err != nil {
return nil, err
}
if err := envNameCIRegexMatch(name); err != nil {
return nil, err
}
v, err := actions_model.InsertVariable(ctx, ownerID, repoID, name, util.ReserveLineBreakForTextarea(data))
if err != nil {
return nil, err
}
return v, nil
}
func UpdateVariable(ctx context.Context, variableID int64, name, data string) (bool, error) {
if err := secret_service.ValidateName(name); err != nil {
return false, err
}
if err := envNameCIRegexMatch(name); err != nil {
return false, err
}
return actions_model.UpdateVariable(ctx, &actions_model.ActionVariable{
ID: variableID,
Name: strings.ToUpper(name),
Data: util.ReserveLineBreakForTextarea(data),
})
}
func DeleteVariableByID(ctx context.Context, variableID int64) error {
return actions_model.DeleteVariable(ctx, variableID)
}
func DeleteVariableByName(ctx context.Context, ownerID, repoID int64, name string) error {
if err := secret_service.ValidateName(name); err != nil {
return err
}
if err := envNameCIRegexMatch(name); err != nil {
return err
}
v, err := GetVariable(ctx, actions_model.FindVariablesOpts{
OwnerID: ownerID,
RepoID: repoID,
Name: name,
})
if err != nil {
return err
}
return actions_model.DeleteVariable(ctx, v.ID)
}
func GetVariable(ctx context.Context, opts actions_model.FindVariablesOpts) (*actions_model.ActionVariable, error) {
vars, err := actions_model.FindVariables(ctx, opts)
if err != nil {
return nil, err
}
if len(vars) != 1 {
return nil, util.NewNotExistErrorf("variable not found")
}
return vars[0], nil
}
// some regular expression of `variables` and `secrets`
// reference to:
// https://docs.github.com/en/actions/learn-github-actions/variables#naming-conventions-for-configuration-variables
// https://docs.github.com/en/actions/security-guides/encrypted-secrets#naming-your-secrets
var (
forbiddenEnvNameCIRx = regexp.MustCompile("(?i)^CI")
)
func envNameCIRegexMatch(name string) error {
if forbiddenEnvNameCIRx.MatchString(name) {
log.Error("Env Name cannot be ci")
return util.NewInvalidArgumentErrorf("env name cannot be ci")
}
return nil
}
+38 -44
View File
@@ -7,6 +7,7 @@ import (
"context"
"fmt"
"os"
"strconv"
"strings"
issues_model "code.gitea.io/gitea/models/issues"
@@ -21,24 +22,21 @@ import (
// ProcReceive handle proc receive work
func ProcReceive(ctx context.Context, repo *repo_model.Repository, gitRepo *git.Repository, opts *private.HookOptions) ([]private.HookProcReceiveRefResult, error) {
// TODO: Add more options?
var (
topicBranch string
title string
description string
forcePush bool
)
results := make([]private.HookProcReceiveRefResult, 0, len(opts.OldCommitIDs))
topicBranch := opts.GitPushOptions["topic"]
forcePush, _ := strconv.ParseBool(opts.GitPushOptions["force-push"])
title := strings.TrimSpace(opts.GitPushOptions["title"])
description := strings.TrimSpace(opts.GitPushOptions["description"]) // TODO: Add more options?
objectFormat := git.ObjectFormatFromName(repo.ObjectFormatName)
userName := strings.ToLower(opts.UserName)
ownerName := repo.OwnerName
repoName := repo.Name
topicBranch = opts.GitPushOptions["topic"]
_, forcePush = opts.GitPushOptions["force-push"]
pusher, err := user_model.GetUserByID(ctx, opts.UserID)
if err != nil {
return nil, fmt.Errorf("failed to get user. Error: %w", err)
}
for i := range opts.OldCommitIDs {
if opts.NewCommitIDs[i] == git.EmptySHA {
if opts.NewCommitIDs[i] == objectFormat.EmptyObjectID().String() {
results = append(results, private.HookProcReceiveRefResult{
OriginalRef: opts.RefFullNames[i],
OldOID: opts.OldCommitIDs[i],
@@ -79,9 +77,6 @@ func ProcReceive(ctx context.Context, repo *repo_model.Repository, gitRepo *git.
continue
}
var headBranch string
userName := strings.ToLower(opts.UserName)
if len(curentTopicBranch) == 0 {
curentTopicBranch = topicBranch
}
@@ -89,6 +84,7 @@ func ProcReceive(ctx context.Context, repo *repo_model.Repository, gitRepo *git.
// because different user maybe want to use same topic,
// So it's better to make sure the topic branch name
// has user name prefix
var headBranch string
if !strings.HasPrefix(curentTopicBranch, userName+"/") {
headBranch = userName + "/" + curentTopicBranch
} else {
@@ -98,26 +94,26 @@ func ProcReceive(ctx context.Context, repo *repo_model.Repository, gitRepo *git.
pr, err := issues_model.GetUnmergedPullRequest(ctx, repo.ID, repo.ID, headBranch, baseBranchName, issues_model.PullRequestFlowAGit)
if err != nil {
if !issues_model.IsErrPullRequestNotExist(err) {
return nil, fmt.Errorf("Failed to get unmerged agit flow pull request in repository: %s/%s Error: %w", ownerName, repoName, err)
return nil, fmt.Errorf("failed to get unmerged agit flow pull request in repository: %s Error: %w", repo.FullName(), err)
}
var commit *git.Commit
if title == "" || description == "" {
commit, err = gitRepo.GetCommit(opts.NewCommitIDs[i])
if err != nil {
return nil, fmt.Errorf("failed to get commit %s in repository: %s Error: %w", opts.NewCommitIDs[i], repo.FullName(), err)
}
}
// create a new pull request
if len(title) == 0 {
var has bool
title, has = opts.GitPushOptions["title"]
if !has || len(title) == 0 {
commit, err := gitRepo.GetCommit(opts.NewCommitIDs[i])
if err != nil {
return nil, fmt.Errorf("Failed to get commit %s in repository: %s/%s Error: %w", opts.NewCommitIDs[i], ownerName, repoName, err)
}
title = strings.Split(commit.CommitMessage, "\n")[0]
}
description = opts.GitPushOptions["description"]
if title == "" {
title = strings.Split(commit.CommitMessage, "\n")[0]
}
pusher, err := user_model.GetUserByID(ctx, opts.UserID)
if err != nil {
return nil, fmt.Errorf("Failed to get user. Error: %w", err)
if description == "" {
_, description, _ = strings.Cut(commit.CommitMessage, "\n\n")
}
if description == "" {
description = title
}
prIssue := &issues_model.Issue{
@@ -151,7 +147,7 @@ func ProcReceive(ctx context.Context, repo *repo_model.Repository, gitRepo *git.
results = append(results, private.HookProcReceiveRefResult{
Ref: pr.GetGitRefName(),
OriginalRef: opts.RefFullNames[i],
OldOID: git.EmptySHA,
OldOID: objectFormat.EmptyObjectID().String(),
NewOID: opts.NewCommitIDs[i],
})
continue
@@ -159,12 +155,12 @@ func ProcReceive(ctx context.Context, repo *repo_model.Repository, gitRepo *git.
// update exist pull request
if err := pr.LoadBaseRepo(ctx); err != nil {
return nil, fmt.Errorf("Unable to load base repository for PR[%d] Error: %w", pr.ID, err)
return nil, fmt.Errorf("unable to load base repository for PR[%d] Error: %w", pr.ID, err)
}
oldCommitID, err := gitRepo.GetRefCommitID(pr.GetGitRefName())
if err != nil {
return nil, fmt.Errorf("Unable to get ref commit id in base repository for PR[%d] Error: %w", pr.ID, err)
return nil, fmt.Errorf("unable to get ref commit id in base repository for PR[%d] Error: %w", pr.ID, err)
}
if oldCommitID == opts.NewCommitIDs[i] {
@@ -178,9 +174,11 @@ func ProcReceive(ctx context.Context, repo *repo_model.Repository, gitRepo *git.
}
if !forcePush {
output, _, err := git.NewCommand(ctx, "rev-list", "--max-count=1").AddDynamicArguments(oldCommitID, "^"+opts.NewCommitIDs[i]).RunStdString(&git.RunOpts{Dir: repo.RepoPath(), Env: os.Environ()})
output, _, err := git.NewCommand(ctx, "rev-list", "--max-count=1").
AddDynamicArguments(oldCommitID, "^"+opts.NewCommitIDs[i]).
RunStdString(&git.RunOpts{Dir: repo.RepoPath(), Env: os.Environ()})
if err != nil {
return nil, fmt.Errorf("Fail to detect force push: %w", err)
return nil, fmt.Errorf("failed to detect force push: %w", err)
} else if len(output) > 0 {
results = append(results, private.HookProcReceiveRefResult{
OriginalRef: opts.RefFullNames[i],
@@ -194,17 +192,13 @@ func ProcReceive(ctx context.Context, repo *repo_model.Repository, gitRepo *git.
pr.HeadCommitID = opts.NewCommitIDs[i]
if err = pull_service.UpdateRef(ctx, pr); err != nil {
return nil, fmt.Errorf("Failed to update pull ref. Error: %w", err)
return nil, fmt.Errorf("failed to update pull ref. Error: %w", err)
}
pull_service.AddToTaskQueue(ctx, pr)
pusher, err := user_model.GetUserByID(ctx, opts.UserID)
if err != nil {
return nil, fmt.Errorf("Failed to get user. Error: %w", err)
}
err = pr.LoadIssue(ctx)
if err != nil {
return nil, fmt.Errorf("Failed to load pull issue. Error: %w", err)
return nil, fmt.Errorf("failed to load pull issue. Error: %w", err)
}
comment, err := pull_service.CreatePushPullComment(ctx, pusher, pr, oldCommitID, opts.NewCommitIDs[i])
if err == nil && comment != nil {
+1 -2
View File
@@ -7,7 +7,6 @@ import (
"context"
"code.gitea.io/gitea/models"
asymkey_model "code.gitea.io/gitea/models/asymkey"
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
)
@@ -27,5 +26,5 @@ func DeleteDeployKey(ctx context.Context, doer *user_model.User, id int64) error
return err
}
return asymkey_model.RewriteAllPublicKeys(ctx)
return RewriteAllPublicKeys(ctx)
}
+21 -6
View File
@@ -13,8 +13,10 @@ import (
"code.gitea.io/gitea/models/db"
git_model "code.gitea.io/gitea/models/git"
issues_model "code.gitea.io/gitea/models/issues"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/gitrepo"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/process"
"code.gitea.io/gitea/modules/setting"
@@ -143,7 +145,10 @@ Loop:
case always:
break Loop
case pubkey:
keys, err := asymkey_model.ListGPGKeys(ctx, u.ID, db.ListOptions{})
keys, err := db.Find[asymkey_model.GPGKey](ctx, asymkey_model.FindGPGKeyOptions{
OwnerID: u.ID,
IncludeSubKeys: true,
})
if err != nil {
return false, "", nil, err
}
@@ -164,7 +169,8 @@ Loop:
}
// SignWikiCommit determines if we should sign the commits to this repository wiki
func SignWikiCommit(ctx context.Context, repoWikiPath string, u *user_model.User) (bool, string, *git.Signature, error) {
func SignWikiCommit(ctx context.Context, repo *repo_model.Repository, u *user_model.User) (bool, string, *git.Signature, error) {
repoWikiPath := repo.WikiPath()
rules := signingModeFromStrings(setting.Repository.Signing.Wiki)
signingKey, sig := SigningKey(ctx, repoWikiPath)
if signingKey == "" {
@@ -179,7 +185,10 @@ Loop:
case always:
break Loop
case pubkey:
keys, err := asymkey_model.ListGPGKeys(ctx, u.ID, db.ListOptions{})
keys, err := db.Find[asymkey_model.GPGKey](ctx, asymkey_model.FindGPGKeyOptions{
OwnerID: u.ID,
IncludeSubKeys: true,
})
if err != nil {
return false, "", nil, err
}
@@ -195,7 +204,7 @@ Loop:
return false, "", nil, &ErrWontSign{twofa}
}
case parentSigned:
gitRepo, err := git.OpenRepository(ctx, repoWikiPath)
gitRepo, err := gitrepo.OpenWikiRepository(ctx, repo)
if err != nil {
return false, "", nil, err
}
@@ -232,7 +241,10 @@ Loop:
case always:
break Loop
case pubkey:
keys, err := asymkey_model.ListGPGKeys(ctx, u.ID, db.ListOptions{})
keys, err := db.Find[asymkey_model.GPGKey](ctx, asymkey_model.FindGPGKeyOptions{
OwnerID: u.ID,
IncludeSubKeys: true,
})
if err != nil {
return false, "", nil, err
}
@@ -294,7 +306,10 @@ Loop:
case always:
break Loop
case pubkey:
keys, err := asymkey_model.ListGPGKeys(ctx, u.ID, db.ListOptions{})
keys, err := db.Find[asymkey_model.GPGKey](ctx, asymkey_model.FindGPGKeyOptions{
OwnerID: u.ID,
IncludeSubKeys: true,
})
if err != nil {
return false, "", nil, err
}
+3 -3
View File
@@ -33,7 +33,7 @@ func DeletePublicKey(ctx context.Context, doer *user_model.User, id int64) (err
}
defer committer.Close()
if err = asymkey_model.DeletePublicKeys(dbCtx, id); err != nil {
if _, err = db.DeleteByID[asymkey_model.PublicKey](dbCtx, id); err != nil {
return err
}
@@ -43,8 +43,8 @@ func DeletePublicKey(ctx context.Context, doer *user_model.User, id int64) (err
committer.Close()
if key.Type == asymkey_model.KeyTypePrincipal {
return asymkey_model.RewriteAllPrincipalKeys(ctx)
return RewriteAllPrincipalKeys(ctx)
}
return asymkey_model.RewriteAllPublicKeys(ctx)
return RewriteAllPublicKeys(ctx)
}
@@ -0,0 +1,79 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"context"
"fmt"
"os"
"path/filepath"
"time"
asymkey_model "code.gitea.io/gitea/models/asymkey"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
)
// RewriteAllPublicKeys removes any authorized key and rewrite all keys from database again.
// Note: db.GetEngine(ctx).Iterate does not get latest data after insert/delete, so we have to call this function
// outside any session scope independently.
func RewriteAllPublicKeys(ctx context.Context) error {
// Don't rewrite key if internal server
if setting.SSH.StartBuiltinServer || !setting.SSH.CreateAuthorizedKeysFile {
return nil
}
return asymkey_model.WithSSHOpLocker(func() error {
return rewriteAllPublicKeys(ctx)
})
}
func rewriteAllPublicKeys(ctx context.Context) error {
if setting.SSH.RootPath != "" {
// First of ensure that the RootPath is present, and if not make it with 0700 permissions
// This of course doesn't guarantee that this is the right directory for authorized_keys
// but at least if it's supposed to be this directory and it doesn't exist and we're the
// right user it will at least be created properly.
err := os.MkdirAll(setting.SSH.RootPath, 0o700)
if err != nil {
log.Error("Unable to MkdirAll(%s): %v", setting.SSH.RootPath, err)
return err
}
}
fPath := filepath.Join(setting.SSH.RootPath, "authorized_keys")
tmpPath := fPath + ".tmp"
t, err := os.OpenFile(tmpPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
return err
}
defer func() {
t.Close()
if err := util.Remove(tmpPath); err != nil {
log.Warn("Unable to remove temporary authorized keys file: %s: Error: %v", tmpPath, err)
}
}()
if setting.SSH.AuthorizedKeysBackup {
isExist, err := util.IsExist(fPath)
if err != nil {
log.Error("Unable to check if %s exists. Error: %v", fPath, err)
return err
}
if isExist {
bakPath := fmt.Sprintf("%s_%d.gitea_bak", fPath, time.Now().Unix())
if err = util.CopyFile(fPath, bakPath); err != nil {
return err
}
}
}
if err := asymkey_model.RegeneratePublicKeys(ctx, t); err != nil {
return err
}
t.Close()
return util.Rename(tmpPath, fPath)
}
@@ -0,0 +1,131 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"bufio"
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
asymkey_model "code.gitea.io/gitea/models/asymkey"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
)
// This file contains functions for creating authorized_principals files
//
// There is a dependence on the database within RewriteAllPrincipalKeys & RegeneratePrincipalKeys
// The sshOpLocker is used from ssh_key_authorized_keys.go
const (
authorizedPrincipalsFile = "authorized_principals"
tplCommentPrefix = `# gitea public key`
)
// RewriteAllPrincipalKeys removes any authorized principal and rewrite all keys from database again.
// Note: db.GetEngine(ctx).Iterate does not get latest data after insert/delete, so we have to call this function
// outside any session scope independently.
func RewriteAllPrincipalKeys(ctx context.Context) error {
// Don't rewrite key if internal server
if setting.SSH.StartBuiltinServer || !setting.SSH.CreateAuthorizedPrincipalsFile {
return nil
}
return asymkey_model.WithSSHOpLocker(func() error {
return rewriteAllPrincipalKeys(ctx)
})
}
func rewriteAllPrincipalKeys(ctx context.Context) error {
if setting.SSH.RootPath != "" {
// First of ensure that the RootPath is present, and if not make it with 0700 permissions
// This of course doesn't guarantee that this is the right directory for authorized_keys
// but at least if it's supposed to be this directory and it doesn't exist and we're the
// right user it will at least be created properly.
err := os.MkdirAll(setting.SSH.RootPath, 0o700)
if err != nil {
log.Error("Unable to MkdirAll(%s): %v", setting.SSH.RootPath, err)
return err
}
}
fPath := filepath.Join(setting.SSH.RootPath, authorizedPrincipalsFile)
tmpPath := fPath + ".tmp"
t, err := os.OpenFile(tmpPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
return err
}
defer func() {
t.Close()
os.Remove(tmpPath)
}()
if setting.SSH.AuthorizedPrincipalsBackup {
isExist, err := util.IsExist(fPath)
if err != nil {
log.Error("Unable to check if %s exists. Error: %v", fPath, err)
return err
}
if isExist {
bakPath := fmt.Sprintf("%s_%d.gitea_bak", fPath, time.Now().Unix())
if err = util.CopyFile(fPath, bakPath); err != nil {
return err
}
}
}
if err := regeneratePrincipalKeys(ctx, t); err != nil {
return err
}
t.Close()
return util.Rename(tmpPath, fPath)
}
func regeneratePrincipalKeys(ctx context.Context, t io.StringWriter) error {
if err := db.GetEngine(ctx).Where("type = ?", asymkey_model.KeyTypePrincipal).Iterate(new(asymkey_model.PublicKey), func(idx int, bean any) (err error) {
_, err = t.WriteString((bean.(*asymkey_model.PublicKey)).AuthorizedString())
return err
}); err != nil {
return err
}
fPath := filepath.Join(setting.SSH.RootPath, authorizedPrincipalsFile)
isExist, err := util.IsExist(fPath)
if err != nil {
log.Error("Unable to check if %s exists. Error: %v", fPath, err)
return err
}
if isExist {
f, err := os.Open(fPath)
if err != nil {
return err
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, tplCommentPrefix) {
scanner.Scan()
continue
}
_, err = t.WriteString(line + "\n")
if err != nil {
return err
}
}
if err = scanner.Err(); err != nil {
return fmt.Errorf("regeneratePrincipalKeys scan: %w", err)
}
}
return nil
}
+54
View File
@@ -0,0 +1,54 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"context"
"fmt"
asymkey_model "code.gitea.io/gitea/models/asymkey"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/perm"
)
// AddPrincipalKey adds new principal to database and authorized_principals file.
func AddPrincipalKey(ctx context.Context, ownerID int64, content string, authSourceID int64) (*asymkey_model.PublicKey, error) {
dbCtx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
defer committer.Close()
// Principals cannot be duplicated.
has, err := db.GetEngine(dbCtx).
Where("content = ? AND type = ?", content, asymkey_model.KeyTypePrincipal).
Get(new(asymkey_model.PublicKey))
if err != nil {
return nil, err
} else if has {
return nil, asymkey_model.ErrKeyAlreadyExist{
Content: content,
}
}
key := &asymkey_model.PublicKey{
OwnerID: ownerID,
Name: content,
Content: content,
Mode: perm.AccessModeWrite,
Type: asymkey_model.KeyTypePrincipal,
LoginSourceID: authSourceID,
}
if err = db.Insert(dbCtx, key); err != nil {
return nil, fmt.Errorf("addKey: %w", err)
}
if err = committer.Commit(); err != nil {
return nil, err
}
committer.Close()
return key, RewriteAllPrincipalKeys(ctx)
}
+4 -1
View File
@@ -67,7 +67,10 @@ ssh-dss AAAAB3NzaC1kc3MAAACBAOChCC7lf6Uo9n7BmZ6M8St19PZf4Tn59NriyboW2x/DZuYAz3ib
for i, kase := range testCases {
s.ID = int64(i) + 20
asymkey_model.AddPublicKeysBySource(db.DefaultContext, user, s, []string{kase.keyString})
keys, err := asymkey_model.ListPublicKeysBySource(db.DefaultContext, user.ID, s.ID)
keys, err := db.Find[asymkey_model.PublicKey](db.DefaultContext, asymkey_model.FindPublicKeyOptions{
OwnerID: user.ID,
LoginSourceID: s.ID,
})
assert.NoError(t, err)
if err != nil {
continue
+4 -4
View File
@@ -12,8 +12,8 @@ import (
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/modules/storage"
"code.gitea.io/gitea/modules/upload"
"code.gitea.io/gitea/modules/util"
"code.gitea.io/gitea/services/context/upload"
"github.com/google/uuid"
)
@@ -39,14 +39,14 @@ func NewAttachment(ctx context.Context, attach *repo_model.Attachment, file io.R
}
// UploadAttachment upload new attachment into storage and update database
func UploadAttachment(ctx context.Context, file io.Reader, allowedTypes string, fileSize int64, opts *repo_model.Attachment) (*repo_model.Attachment, error) {
func UploadAttachment(ctx context.Context, file io.Reader, allowedTypes string, fileSize int64, attach *repo_model.Attachment) (*repo_model.Attachment, error) {
buf := make([]byte, 1024)
n, _ := util.ReadAtMost(file, buf)
buf = buf[:n]
if err := upload.Verify(buf, opts.Name, allowedTypes); err != nil {
if err := upload.Verify(buf, attach.Name, allowedTypes); err != nil {
return nil, err
}
return NewAttachment(ctx, opts, io.MultiReader(bytes.NewReader(buf), file), fileSize)
return NewAttachment(ctx, attach, io.MultiReader(bytes.NewReader(buf), file), fileSize)
}
+12 -3
View File
@@ -12,11 +12,13 @@ import (
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/auth/webauthn"
gitea_context "code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/optional"
"code.gitea.io/gitea/modules/session"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/web/middleware"
gitea_context "code.gitea.io/gitea/services/context"
user_service "code.gitea.io/gitea/services/user"
)
// Init should be called exactly once when the application starts to allow plugins
@@ -38,6 +40,7 @@ func isContainerPath(req *http.Request) bool {
var (
gitRawOrAttachPathRe = regexp.MustCompile(`^/[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+/(?:(?:git-(?:(?:upload)|(?:receive))-pack$)|(?:info/refs$)|(?:HEAD$)|(?:objects/)|(?:raw/)|(?:releases/download/)|(?:attachments/))`)
lfsPathRe = regexp.MustCompile(`^/[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+/info/lfs/`)
archivePathRe = regexp.MustCompile(`^/[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+/archive/`)
)
func isGitRawOrAttachPath(req *http.Request) bool {
@@ -54,6 +57,10 @@ func isGitRawOrAttachOrLFSPath(req *http.Request) bool {
return false
}
func isArchivePath(req *http.Request) bool {
return archivePathRe.MatchString(req.URL.Path)
}
// handleSignIn clears existing session variables and stores new ones for the specified user object
func handleSignIn(resp http.ResponseWriter, req *http.Request, sess SessionStore, user *user_model.User) {
// We need to regenerate the session...
@@ -85,8 +92,10 @@ func handleSignIn(resp http.ResponseWriter, req *http.Request, sess SessionStore
// If the user does not have a locale set, we save the current one.
if len(user.Language) == 0 {
lc := middleware.Locale(resp, req)
user.Language = lc.Language()
if err := user_model.UpdateUserCols(req.Context(), user, "language"); err != nil {
opts := &user_service.UpdateOptions{
Language: optional.Some(lc.Language()),
}
if err := user_service.UpdateUser(req.Context(), user, opts); err != nil {
log.Error(fmt.Sprintf("Error updating user language [user: %d, locale: %s]", user.ID, user.Language))
return
}
+5 -5
View File
@@ -37,14 +37,14 @@ func TestCheckAuthToken(t *testing.T) {
})
t.Run("Expired", func(t *testing.T) {
timeutil.Set(time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC))
timeutil.MockSet(time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC))
at, token, err := CreateAuthTokenForUserID(db.DefaultContext, 2)
assert.NoError(t, err)
assert.NotNil(t, at)
assert.NotEmpty(t, token)
timeutil.Unset()
timeutil.MockUnset()
at2, err := CheckAuthToken(db.DefaultContext, at.ID+":"+token)
assert.ErrorIs(t, err, ErrAuthTokenExpired)
@@ -83,15 +83,15 @@ func TestCheckAuthToken(t *testing.T) {
func TestRegenerateAuthToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
timeutil.Set(time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC))
defer timeutil.Unset()
timeutil.MockSet(time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC))
defer timeutil.MockUnset()
at, token, err := CreateAuthTokenForUserID(db.DefaultContext, 2)
assert.NoError(t, err)
assert.NotNil(t, at)
assert.NotEmpty(t, token)
timeutil.Set(time.Date(2023, 1, 1, 0, 0, 1, 0, time.UTC))
timeutil.MockSet(time.Date(2023, 1, 1, 0, 0, 1, 0, time.UTC))
at2, token2, err := RegenerateAuthToken(db.DefaultContext, at)
assert.NoError(t, err)
+22 -2
View File
@@ -15,6 +15,7 @@ import (
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
"code.gitea.io/gitea/modules/web/middleware"
)
@@ -131,11 +132,30 @@ func (b *Basic) Verify(req *http.Request, w http.ResponseWriter, store DataStore
return nil, err
}
if skipper, ok := source.Cfg.(LocalTwoFASkipper); ok && skipper.IsSkipLocalTwoFA() {
store.GetData()["SkipLocalTwoFA"] = true
if skipper, ok := source.Cfg.(LocalTwoFASkipper); !ok || !skipper.IsSkipLocalTwoFA() {
if err := validateTOTP(req, u); err != nil {
return nil, err
}
}
log.Trace("Basic Authorization: Logged in user %-v", u)
return u, nil
}
func validateTOTP(req *http.Request, u *user_model.User) error {
twofa, err := auth_model.GetTwoFactorByUID(req.Context(), u.ID)
if err != nil {
if auth_model.IsErrTwoFactorNotEnrolled(err) {
// No 2FA enrollment for this user
return nil
}
return err
}
if ok, err := twofa.ValidateTOTP(req.Header.Get("X-Gitea-OTP")); err != nil {
return err
} else if !ok {
return util.NewInvalidArgumentErrorf("invalid provided OTP")
}
return nil
}
+4 -1
View File
@@ -12,6 +12,7 @@ import (
"strings"
asymkey_model "code.gitea.io/gitea/models/asymkey"
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
@@ -92,7 +93,9 @@ func VerifyPubKey(r *http.Request) (*asymkey_model.PublicKey, error) {
keyID := verifier.KeyId()
publicKeys, err := asymkey_model.SearchPublicKey(r.Context(), 0, keyID)
publicKeys, err := db.Find[asymkey_model.PublicKey](r.Context(), asymkey_model.FindPublicKeyOptions{
Fingerprint: keyID,
})
if err != nil {
return nil, err
}
+14 -8
View File
@@ -14,6 +14,7 @@ import (
auth_model "code.gitea.io/gitea/models/auth"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/web/middleware"
"code.gitea.io/gitea/services/auth/source/oauth2"
@@ -62,14 +63,19 @@ func (o *OAuth2) Name() string {
// representing whether the token exists or not
func parseToken(req *http.Request) (string, bool) {
_ = req.ParseForm()
// Check token.
if token := req.Form.Get("token"); token != "" {
return token, true
}
// Check access token.
if token := req.Form.Get("access_token"); token != "" {
return token, true
if !setting.DisableQueryAuthToken {
// Check token.
if token := req.Form.Get("token"); token != "" {
return token, true
}
// Check access token.
if token := req.Form.Get("access_token"); token != "" {
return token, true
}
} else if req.Form.Get("token") != "" || req.Form.Get("access_token") != "" {
log.Warn("API token sent in query string but DISABLE_QUERY_AUTH_TOKEN=true")
}
// check header token
if auHead := req.Header.Get("Authorization"); auHead != "" {
auths := strings.Fields(auHead)
@@ -127,7 +133,7 @@ func (o *OAuth2) userIDFromToken(ctx context.Context, tokenSHA string, store Dat
func (o *OAuth2) Verify(req *http.Request, w http.ResponseWriter, store DataStore, sess SessionStore) (*user_model.User, error) {
// These paths are not API paths, but we still want to check for tokens because they maybe in the API returned URLs
if !middleware.IsAPIPath(req) && !isAttachmentDownload(req) && !isAuthenticatedTokenRequest(req) &&
!isGitRawOrAttachPath(req) {
!isGitRawOrAttachPath(req) && !isArchivePath(req) {
return nil, nil
}
+2 -2
View File
@@ -10,8 +10,8 @@ import (
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/optional"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
"code.gitea.io/gitea/modules/web/middleware"
gouuid "github.com/google/uuid"
@@ -161,7 +161,7 @@ func (r *ReverseProxy) newUser(req *http.Request) *user_model.User {
}
overwriteDefault := user_model.CreateUserOverwriteOptions{
IsActive: util.OptionalBoolTrue,
IsActive: optional.Some(true),
}
if err := user_model.CreateUser(req.Context(), user, &overwriteDefault); err != nil {
+9 -17
View File
@@ -4,7 +4,6 @@
package auth
import (
"context"
"net/http"
user_model "code.gitea.io/gitea/models/user"
@@ -29,40 +28,33 @@ func (s *Session) Name() string {
// object for that uid.
// Returns nil if there is no user uid stored in the session.
func (s *Session) Verify(req *http.Request, w http.ResponseWriter, store DataStore, sess SessionStore) (*user_model.User, error) {
user := SessionUser(req.Context(), sess)
if user != nil {
return user, nil
}
return nil, nil
}
// SessionUser returns the user object corresponding to the "uid" session variable.
func SessionUser(ctx context.Context, sess SessionStore) *user_model.User {
if sess == nil {
return nil
return nil, nil
}
// Get user ID
uid := sess.Get("uid")
if uid == nil {
return nil
return nil, nil
}
log.Trace("Session Authorization: Found user[%d]", uid)
id, ok := uid.(int64)
if !ok {
return nil
return nil, nil
}
// Get user object
user, err := user_model.GetUserByID(ctx, id)
user, err := user_model.GetUserByID(req.Context(), id)
if err != nil {
if !user_model.IsErrUserNotExist(err) {
log.Error("GetUserById: %v", err)
log.Error("GetUserByID: %v", err)
// Return the err as-is to keep current signed-in session, in case the err is something like context.Canceled. Otherwise non-existing user (nil, nil) will make the caller clear the signed-in session.
return nil, err
}
return nil
return nil, nil
}
log.Trace("Session Authorization: Logged in user %-v", user)
return user
return user, nil
}
+4 -1
View File
@@ -11,6 +11,7 @@ import (
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/optional"
"code.gitea.io/gitea/services/auth/source/oauth2"
"code.gitea.io/gitea/services/auth/source/smtp"
@@ -85,7 +86,9 @@ func UserSignIn(ctx context.Context, username, password string) (*user_model.Use
}
}
sources, err := auth.AllActiveSources(ctx)
sources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{
IsActive: optional.Some(true),
})
if err != nil {
return nil, nil, err
}
+1 -1
View File
@@ -18,7 +18,7 @@ func (source *Source) FromDB(bs []byte) error {
return nil
}
// ToDB exports an SMTPConfig to a serialized format.
// ToDB exports the config to a byte slice to be saved into database (this method is just dummy and does nothing for DB source)
func (source *Source) ToDB() ([]byte, error) {
return nil, nil
}
@@ -12,7 +12,8 @@ import (
"code.gitea.io/gitea/models/auth"
user_model "code.gitea.io/gitea/models/user"
auth_module "code.gitea.io/gitea/modules/auth"
"code.gitea.io/gitea/modules/util"
"code.gitea.io/gitea/modules/optional"
asymkey_service "code.gitea.io/gitea/services/asymkey"
source_service "code.gitea.io/gitea/services/auth/source"
user_service "code.gitea.io/gitea/services/user"
)
@@ -49,20 +50,17 @@ func (source *Source) Authenticate(ctx context.Context, user *user_model.User, u
}
}
if user != nil && !user.ProhibitLogin {
cols := make([]string, 0)
opts := &user_service.UpdateOptions{}
if len(source.AdminFilter) > 0 && user.IsAdmin != sr.IsAdmin {
// Change existing admin flag only if AdminFilter option is set
user.IsAdmin = sr.IsAdmin
cols = append(cols, "is_admin")
opts.IsAdmin = optional.Some(sr.IsAdmin)
}
if !user.IsAdmin && len(source.RestrictedFilter) > 0 && user.IsRestricted != sr.IsRestricted {
if !sr.IsAdmin && len(source.RestrictedFilter) > 0 && user.IsRestricted != sr.IsRestricted {
// Change existing restricted flag only if RestrictedFilter option is set
user.IsRestricted = sr.IsRestricted
cols = append(cols, "is_restricted")
opts.IsRestricted = optional.Some(sr.IsRestricted)
}
if len(cols) > 0 {
err = user_model.UpdateUserCols(ctx, user, cols...)
if err != nil {
if opts.IsAdmin.Has() || opts.IsRestricted.Has() {
if err := user_service.UpdateUser(ctx, user, opts); err != nil {
return nil, err
}
}
@@ -71,7 +69,7 @@ func (source *Source) Authenticate(ctx context.Context, user *user_model.User, u
if user != nil {
if isAttributeSSHPublicKeySet && asymkey_model.SynchronizePublicKeys(ctx, user, source.authSource, sr.SSHPublicKey) {
if err := asymkey_model.RewriteAllPublicKeys(ctx); err != nil {
if err := asymkey_service.RewriteAllPublicKeys(ctx); err != nil {
return user, err
}
}
@@ -87,8 +85,8 @@ func (source *Source) Authenticate(ctx context.Context, user *user_model.User, u
IsAdmin: sr.IsAdmin,
}
overwriteDefault := &user_model.CreateUserOverwriteOptions{
IsRestricted: util.OptionalBoolOf(sr.IsRestricted),
IsActive: util.OptionalBoolTrue,
IsRestricted: optional.Some(sr.IsRestricted),
IsActive: optional.Some(true),
}
err := user_model.CreateUser(ctx, user, overwriteDefault)
@@ -97,7 +95,7 @@ func (source *Source) Authenticate(ctx context.Context, user *user_model.User, u
}
if isAttributeSSHPublicKeySet && asymkey_model.AddPublicKeysBySource(ctx, user, source.authSource, sr.SSHPublicKey) {
if err := asymkey_model.RewriteAllPublicKeys(ctx); err != nil {
if err := asymkey_service.RewriteAllPublicKeys(ctx); err != nil {
return user, err
}
}
+23 -19
View File
@@ -15,7 +15,8 @@ import (
auth_module "code.gitea.io/gitea/modules/auth"
"code.gitea.io/gitea/modules/container"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/util"
"code.gitea.io/gitea/modules/optional"
asymkey_service "code.gitea.io/gitea/services/asymkey"
source_service "code.gitea.io/gitea/services/auth/source"
user_service "code.gitea.io/gitea/services/user"
)
@@ -77,7 +78,7 @@ func (source *Source) Sync(ctx context.Context, updateExisting bool) error {
log.Warn("SyncExternalUsers: Cancelled at update of %s before completed update of users", source.authSource.Name)
// Rewrite authorized_keys file if LDAP Public SSH Key attribute is set and any key was added or removed
if sshKeysNeedUpdate {
err = asymkey_model.RewriteAllPublicKeys(ctx)
err = asymkey_service.RewriteAllPublicKeys(ctx)
if err != nil {
log.Error("RewriteAllPublicKeys: %v", err)
}
@@ -124,8 +125,8 @@ func (source *Source) Sync(ctx context.Context, updateExisting bool) error {
IsAdmin: su.IsAdmin,
}
overwriteDefault := &user_model.CreateUserOverwriteOptions{
IsRestricted: util.OptionalBoolOf(su.IsRestricted),
IsActive: util.OptionalBoolTrue,
IsRestricted: optional.Some(su.IsRestricted),
IsActive: optional.Some(true),
}
err = user_model.CreateUser(ctx, usr, overwriteDefault)
@@ -158,23 +159,25 @@ func (source *Source) Sync(ctx context.Context, updateExisting bool) error {
log.Trace("SyncExternalUsers[%s]: Updating user %s", source.authSource.Name, usr.Name)
usr.FullName = fullName
emailChanged := usr.Email != su.Mail
usr.Email = su.Mail
// Change existing admin flag only if AdminFilter option is set
if len(source.AdminFilter) > 0 {
usr.IsAdmin = su.IsAdmin
opts := &user_service.UpdateOptions{
FullName: optional.Some(fullName),
IsActive: optional.Some(true),
}
if source.AdminFilter != "" {
opts.IsAdmin = optional.Some(su.IsAdmin)
}
// Change existing restricted flag only if RestrictedFilter option is set
if !usr.IsAdmin && len(source.RestrictedFilter) > 0 {
usr.IsRestricted = su.IsRestricted
if !su.IsAdmin && source.RestrictedFilter != "" {
opts.IsRestricted = optional.Some(su.IsRestricted)
}
usr.IsActive = true
err = user_model.UpdateUser(ctx, usr, emailChanged, "full_name", "email", "is_admin", "is_restricted", "is_active")
if err != nil {
if err := user_service.UpdateUser(ctx, usr, opts); err != nil {
log.Error("SyncExternalUsers[%s]: Error updating user %s: %v", source.authSource.Name, usr.Name, err)
}
if err := user_service.ReplacePrimaryEmailAddress(ctx, usr, su.Mail); err != nil {
log.Error("SyncExternalUsers[%s]: Error updating user %s primary email %s: %v", source.authSource.Name, usr.Name, su.Mail, err)
}
}
if usr.IsUploadAvatarChanged(su.Avatar) {
@@ -193,7 +196,7 @@ func (source *Source) Sync(ctx context.Context, updateExisting bool) error {
// Rewrite authorized_keys file if LDAP Public SSH Key attribute is set and any key was added or removed
if sshKeysNeedUpdate {
err = asymkey_model.RewriteAllPublicKeys(ctx)
err = asymkey_service.RewriteAllPublicKeys(ctx)
if err != nil {
log.Error("RewriteAllPublicKeys: %v", err)
}
@@ -215,9 +218,10 @@ func (source *Source) Sync(ctx context.Context, updateExisting bool) error {
log.Trace("SyncExternalUsers[%s]: Deactivating user %s", source.authSource.Name, usr.Name)
usr.IsActive = false
err = user_model.UpdateUserCols(ctx, usr, "is_active")
if err != nil {
opts := &user_service.UpdateOptions{
IsActive: optional.Some(false),
}
if err := user_service.UpdateUser(ctx, usr, opts); err != nil {
log.Error("SyncExternalUsers[%s]: Error deactivating user %s: %v", source.authSource.Name, usr.Name, err)
}
}
+9 -1
View File
@@ -10,7 +10,9 @@ import (
"sync"
"code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/optional"
"code.gitea.io/gitea/modules/setting"
"github.com/google/uuid"
@@ -63,7 +65,13 @@ func ResetOAuth2(ctx context.Context) error {
// initOAuth2Sources is used to load and register all active OAuth2 providers
func initOAuth2Sources(ctx context.Context) error {
authSources, _ := auth.GetActiveOAuth2ProviderSources(ctx)
authSources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{
IsActive: optional.Some(true),
LoginType: auth.OAuth2,
})
if err != nil {
return err
}
for _, source := range authSources {
oauth2Source, ok := source.Cfg.(*Source)
if !ok {
+1 -7
View File
@@ -300,7 +300,7 @@ func InitSigningKey() error {
case "HS384":
fallthrough
case "HS512":
key, err = loadSymmetricKey()
key = setting.GetGeneralTokenSigningSecret()
case "RS256":
fallthrough
case "RS384":
@@ -333,12 +333,6 @@ func InitSigningKey() error {
return nil
}
// loadSymmetricKey checks if the configured secret is valid.
// If it is not valid, it will return an error.
func loadSymmetricKey() (any, error) {
return util.Base64FixedDecode(base64.RawURLEncoding, []byte(setting.OAuth2.JWTSecretBase64), 32)
}
// loadOrCreateAsymmetricKey checks if the configured private key exists.
// If it does not exist a new random key gets generated and saved on the configured path.
func loadOrCreateAsymmetricKey() (any, error) {
+29 -21
View File
@@ -13,7 +13,9 @@ import (
"sort"
"code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/optional"
"code.gitea.io/gitea/modules/setting"
"github.com/markbates/goth"
@@ -57,7 +59,7 @@ func (p *AuthSourceProvider) DisplayName() string {
func (p *AuthSourceProvider) IconHTML(size int) template.HTML {
if p.iconURL != "" {
img := fmt.Sprintf(`<img class="gt-object-contain gt-mr-3" width="%d" height="%d" src="%s" alt="%s">`,
img := fmt.Sprintf(`<img class="tw-object-contain tw-mr-2" width="%d" height="%d" src="%s" alt="%s">`,
size,
size,
html.EscapeString(p.iconURL), html.EscapeString(p.DisplayName()),
@@ -80,10 +82,10 @@ func RegisterGothProvider(provider GothProvider) {
gothProviders[provider.Name()] = provider
}
// GetOAuth2Providers returns the map of unconfigured OAuth2 providers
// GetSupportedOAuth2Providers returns the map of unconfigured OAuth2 providers
// key is used as technical name (like in the callbackURL)
// values to display
func GetOAuth2Providers() []Provider {
func GetSupportedOAuth2Providers() []Provider {
providers := make([]Provider, 0, len(gothProviders))
for _, provider := range gothProviders {
@@ -95,33 +97,39 @@ func GetOAuth2Providers() []Provider {
return providers
}
// GetActiveOAuth2Providers returns the map of configured active OAuth2 providers
// key is used as technical name (like in the callbackURL)
// values to display
func GetActiveOAuth2Providers(ctx context.Context) ([]string, map[string]Provider, error) {
// Maybe also separate used and unused providers so we can force the registration of only 1 active provider for each type
func CreateProviderFromSource(source *auth.Source) (Provider, error) {
oauth2Cfg, ok := source.Cfg.(*Source)
if !ok {
return nil, fmt.Errorf("invalid OAuth2 source config: %v", oauth2Cfg)
}
gothProv := gothProviders[oauth2Cfg.Provider]
return &AuthSourceProvider{GothProvider: gothProv, sourceName: source.Name, iconURL: oauth2Cfg.IconURL}, nil
}
authSources, err := auth.GetActiveOAuth2ProviderSources(ctx)
// GetOAuth2Providers returns the list of configured OAuth2 providers
func GetOAuth2Providers(ctx context.Context, isActive optional.Option[bool]) ([]Provider, error) {
authSources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{
IsActive: isActive,
LoginType: auth.OAuth2,
})
if err != nil {
return nil, nil, err
return nil, err
}
var orderedKeys []string
providers := make(map[string]Provider)
providers := make([]Provider, 0, len(authSources))
for _, source := range authSources {
oauth2Cfg, ok := source.Cfg.(*Source)
if !ok {
log.Error("Invalid OAuth2 source config: %v", oauth2Cfg)
continue
provider, err := CreateProviderFromSource(source)
if err != nil {
return nil, err
}
gothProv := gothProviders[oauth2Cfg.Provider]
providers[source.Name] = &AuthSourceProvider{GothProvider: gothProv, sourceName: source.Name, iconURL: oauth2Cfg.IconURL}
orderedKeys = append(orderedKeys, source.Name)
providers = append(providers, provider)
}
sort.Strings(orderedKeys)
sort.Slice(providers, func(i, j int) bool {
return providers[i].Name() < providers[j].Name()
})
return orderedKeys, providers, nil
return providers, nil
}
// RegisterProviderWithGothic register a OAuth2 provider in goth lib
@@ -35,10 +35,10 @@ func (b *BaseProvider) IconHTML(size int) template.HTML {
case "github":
svgName = "octicon-mark-github"
}
svgHTML := svg.RenderHTML(svgName, size, "gt-mr-3")
svgHTML := svg.RenderHTML(svgName, size, "tw-mr-2")
if svgHTML == "" {
log.Error("No SVG icon for oauth2 provider %q", b.name)
svgHTML = svg.RenderHTML("gitea-openid", size, "gt-mr-3")
svgHTML = svg.RenderHTML("gitea-openid", size, "tw-mr-2")
}
return svgHTML
}
@@ -29,7 +29,7 @@ func (o *OpenIDProvider) DisplayName() string {
// IconHTML returns icon HTML for this provider
func (o *OpenIDProvider) IconHTML(size int) template.HTML {
return svg.RenderHTML("gitea-openid", size, "gt-mr-3")
return svg.RenderHTML("gitea-openid", size, "tw-mr-2")
}
// CreateGothProvider creates a GothProvider from this Provider
@@ -11,8 +11,8 @@ import (
"code.gitea.io/gitea/models/auth"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/auth/pam"
"code.gitea.io/gitea/modules/optional"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
"github.com/google/uuid"
)
@@ -60,7 +60,7 @@ func (source *Source) Authenticate(ctx context.Context, user *user_model.User, u
LoginName: userName, // This is what the user typed in
}
overwriteDefault := &user_model.CreateUserOverwriteOptions{
IsActive: util.OptionalBoolTrue,
IsActive: optional.Some(true),
}
if err := user_model.CreateUser(ctx, user, overwriteDefault); err != nil {
@@ -12,6 +12,7 @@ import (
auth_model "code.gitea.io/gitea/models/auth"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/optional"
"code.gitea.io/gitea/modules/util"
)
@@ -75,7 +76,7 @@ func (source *Source) Authenticate(ctx context.Context, user *user_model.User, u
LoginName: userName,
}
overwriteDefault := &user_model.CreateUserOverwriteOptions{
IsActive: util.OptionalBoolTrue,
IsActive: optional.Some(true),
}
if err := user_model.CreateUser(ctx, user, overwriteDefault); err != nil {
+2 -2
View File
@@ -100,12 +100,12 @@ func syncGroupsToTeamsCached(ctx context.Context, user *user_model.User, orgTeam
}
if action == syncAdd && !isMember {
if err := models.AddTeamMember(ctx, team, user.ID); err != nil {
if err := models.AddTeamMember(ctx, team, user); err != nil {
log.Error("group sync: Could not add user to team: %v", err)
return err
}
} else if action == syncRemove && isMember {
if err := models.RemoveTeamMember(ctx, team, user.ID); err != nil {
if err := models.RemoveTeamMember(ctx, team, user); err != nil {
log.Error("group sync: Could not remove user from team: %v", err)
return err
}
+12 -12
View File
@@ -11,15 +11,15 @@ import (
"sync"
"code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/avatars"
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/base"
gitea_context "code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/optional"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
"code.gitea.io/gitea/modules/web/middleware"
"code.gitea.io/gitea/services/auth/source/sspi"
gitea_context "code.gitea.io/gitea/services/context"
gouuid "github.com/google/uuid"
)
@@ -130,7 +130,10 @@ func (s *SSPI) Verify(req *http.Request, w http.ResponseWriter, store DataStore,
// getConfig retrieves the SSPI configuration from login sources
func (s *SSPI) getConfig(ctx context.Context) (*sspi.Source, error) {
sources, err := auth.ActiveSources(ctx, auth.SSPI)
sources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{
IsActive: optional.Some(true),
LoginType: auth.SSPI,
})
if err != nil {
return nil, err
}
@@ -163,17 +166,14 @@ func (s *SSPI) shouldAuthenticate(req *http.Request) (shouldAuth bool) {
func (s *SSPI) newUser(ctx context.Context, username string, cfg *sspi.Source) (*user_model.User, error) {
email := gouuid.New().String() + "@localhost.localdomain"
user := &user_model.User{
Name: username,
Email: email,
Passwd: gouuid.New().String(),
Language: cfg.DefaultLanguage,
UseCustomAvatar: true,
Avatar: avatars.DefaultAvatarLink(),
Name: username,
Email: email,
Language: cfg.DefaultLanguage,
}
emailNotificationPreference := user_model.EmailNotificationsDisabled
overwriteDefault := &user_model.CreateUserOverwriteOptions{
IsActive: util.OptionalBoolOf(cfg.AutoActivateUsers),
KeepEmailPrivate: util.OptionalBoolTrue,
IsActive: optional.Some(cfg.AutoActivateUsers),
KeepEmailPrivate: optional.Some(true),
EmailNotificationsPreference: &emailNotificationPreference,
}
if err := user_model.CreateUser(ctx, user, overwriteDefault); err != nil {
+1 -1
View File
@@ -15,7 +15,7 @@ import (
func SyncExternalUsers(ctx context.Context, updateExisting bool) error {
log.Trace("Doing: SyncExternalUsers")
ls, err := auth.Sources(ctx)
ls, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{})
if err != nil {
log.Error("SyncExternalUsers: %v", err)
return err
+4 -3
View File
@@ -17,6 +17,7 @@ import (
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/gitrepo"
"code.gitea.io/gitea/modules/graceful"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/process"
@@ -111,7 +112,7 @@ func MergeScheduledPullRequest(ctx context.Context, sha string, repo *repo_model
}
func getPullRequestsByHeadSHA(ctx context.Context, sha string, repo *repo_model.Repository, filter func(*issues_model.PullRequest) bool) (map[int64]*issues_model.PullRequest, error) {
gitRepo, err := git.OpenRepository(ctx, repo.RepoPath())
gitRepo, err := gitrepo.OpenRepository(ctx, repo)
if err != nil {
return nil, err
}
@@ -190,7 +191,7 @@ func handlePull(pullID int64, sha string) {
return
}
headGitRepo, err := git.OpenRepository(ctx, pr.HeadRepo.RepoPath())
headGitRepo, err := gitrepo.OpenRepository(ctx, pr.HeadRepo)
if err != nil {
log.Error("OpenRepository %-v: %v", pr.HeadRepo, err)
return
@@ -246,7 +247,7 @@ func handlePull(pullID int64, sha string) {
return
}
baseGitRepo, err = git.OpenRepository(ctx, pr.BaseRepo.RepoPath())
baseGitRepo, err = gitrepo.OpenRepository(ctx, pr.BaseRepo)
if err != nil {
log.Error("OpenRepository %-v: %v", pr.BaseRepo, err)
return
+101
View File
@@ -0,0 +1,101 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"bytes"
"fmt"
"net"
"net/http"
"strings"
"text/template"
"time"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/web/middleware"
)
type routerLoggerOptions struct {
req *http.Request
Identity *string
Start *time.Time
ResponseWriter http.ResponseWriter
Ctx map[string]any
RequestID *string
}
const keyOfRequestIDInTemplate = ".RequestID"
// According to:
// TraceId: A valid trace identifier is a 16-byte array with at least one non-zero byte
// MD5 output is 16 or 32 bytes: md5-bytes is 16, md5-hex is 32
// SHA1: similar, SHA1-bytes is 20, SHA1-hex is 40.
// UUID is 128-bit, 32 hex chars, 36 ASCII chars with 4 dashes
// So, we accept a Request ID with a maximum character length of 40
const maxRequestIDByteLength = 40
func parseRequestIDFromRequestHeader(req *http.Request) string {
requestID := "-"
for _, key := range setting.Log.RequestIDHeaders {
if req.Header.Get(key) != "" {
requestID = req.Header.Get(key)
break
}
}
if len(requestID) > maxRequestIDByteLength {
requestID = fmt.Sprintf("%s...", requestID[:maxRequestIDByteLength])
}
return requestID
}
// AccessLogger returns a middleware to log access logger
func AccessLogger() func(http.Handler) http.Handler {
logger := log.GetLogger("access")
needRequestID := len(setting.Log.RequestIDHeaders) > 0 && strings.Contains(setting.Log.AccessLogTemplate, keyOfRequestIDInTemplate)
logTemplate, _ := template.New("log").Parse(setting.Log.AccessLogTemplate)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
start := time.Now()
var requestID string
if needRequestID {
requestID = parseRequestIDFromRequestHeader(req)
}
reqHost, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
reqHost = req.RemoteAddr
}
next.ServeHTTP(w, req)
rw := w.(ResponseWriter)
identity := "-"
data := middleware.GetContextData(req.Context())
if signedUser, ok := data[middleware.ContextDataKeySignedUser].(*user_model.User); ok {
identity = signedUser.Name
}
buf := bytes.NewBuffer([]byte{})
err = logTemplate.Execute(buf, routerLoggerOptions{
req: req,
Identity: &identity,
Start: &start,
ResponseWriter: rw,
Ctx: map[string]any{
"RemoteAddr": req.RemoteAddr,
"RemoteHost": reqHost,
"Req": req,
},
RequestID: &requestID,
})
if err != nil {
log.Error("Could not execute access logger template: %v", err.Error())
}
logger.Info("%s", buf.String())
})
}
}
+408
View File
@@ -0,0 +1,408 @@
// Copyright 2016 The Gogs Authors. All rights reserved.
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"code.gitea.io/gitea/models/unit"
user_model "code.gitea.io/gitea/models/user"
mc "code.gitea.io/gitea/modules/cache"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/gitrepo"
"code.gitea.io/gitea/modules/httpcache"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/web"
web_types "code.gitea.io/gitea/modules/web/types"
"gitea.com/go-chi/cache"
)
// APIContext is a specific context for API service
type APIContext struct {
*Base
Cache cache.Cache
Doer *user_model.User // current signed-in user
IsSigned bool
IsBasicAuth bool
ContextUser *user_model.User // the user which is being visited, in most cases it differs from Doer
Repo *Repository
Org *APIOrganization
Package *Package
}
func init() {
web.RegisterResponseStatusProvider[*APIContext](func(req *http.Request) web_types.ResponseStatusProvider {
return req.Context().Value(apiContextKey).(*APIContext)
})
}
// Currently, we have the following common fields in error response:
// * message: the message for end users (it shouldn't be used for error type detection)
// if we need to indicate some errors, we should introduce some new fields like ErrorCode or ErrorType
// * url: the swagger document URL
// APIError is error format response
// swagger:response error
type APIError struct {
Message string `json:"message"`
URL string `json:"url"`
}
// APIValidationError is error format response related to input validation
// swagger:response validationError
type APIValidationError struct {
Message string `json:"message"`
URL string `json:"url"`
}
// APIInvalidTopicsError is error format response to invalid topics
// swagger:response invalidTopicsError
type APIInvalidTopicsError struct {
Message string `json:"message"`
InvalidTopics []string `json:"invalidTopics"`
}
// APIEmpty is an empty response
// swagger:response empty
type APIEmpty struct{}
// APIForbiddenError is a forbidden error response
// swagger:response forbidden
type APIForbiddenError struct {
APIError
}
// APINotFound is a not found empty response
// swagger:response notFound
type APINotFound struct{}
// APIConflict is a conflict empty response
// swagger:response conflict
type APIConflict struct{}
// APIRedirect is a redirect response
// swagger:response redirect
type APIRedirect struct{}
// APIString is a string response
// swagger:response string
type APIString string
// APIRepoArchivedError is an error that is raised when an archived repo should be modified
// swagger:response repoArchivedError
type APIRepoArchivedError struct {
APIError
}
// ServerError responds with error message, status is 500
func (ctx *APIContext) ServerError(title string, err error) {
ctx.Error(http.StatusInternalServerError, title, err)
}
// Error responds with an error message to client with given obj as the message.
// If status is 500, also it prints error to log.
func (ctx *APIContext) Error(status int, title string, obj any) {
var message string
if err, ok := obj.(error); ok {
message = err.Error()
} else {
message = fmt.Sprintf("%s", obj)
}
if status == http.StatusInternalServerError {
log.ErrorWithSkip(1, "%s: %s", title, message)
if setting.IsProd && !(ctx.Doer != nil && ctx.Doer.IsAdmin) {
message = ""
}
}
ctx.JSON(status, APIError{
Message: message,
URL: setting.API.SwaggerURL,
})
}
// InternalServerError responds with an error message to the client with the error as a message
// and the file and line of the caller.
func (ctx *APIContext) InternalServerError(err error) {
log.ErrorWithSkip(1, "InternalServerError: %v", err)
var message string
if !setting.IsProd || (ctx.Doer != nil && ctx.Doer.IsAdmin) {
message = err.Error()
}
ctx.JSON(http.StatusInternalServerError, APIError{
Message: message,
URL: setting.API.SwaggerURL,
})
}
type apiContextKeyType struct{}
var apiContextKey = apiContextKeyType{}
// GetAPIContext returns a context for API routes
func GetAPIContext(req *http.Request) *APIContext {
return req.Context().Value(apiContextKey).(*APIContext)
}
func genAPILinks(curURL *url.URL, total, pageSize, curPage int) []string {
page := NewPagination(total, pageSize, curPage, 0)
paginater := page.Paginater
links := make([]string, 0, 4)
if paginater.HasNext() {
u := *curURL
queries := u.Query()
queries.Set("page", fmt.Sprintf("%d", paginater.Next()))
u.RawQuery = queries.Encode()
links = append(links, fmt.Sprintf("<%s%s>; rel=\"next\"", setting.AppURL, u.RequestURI()[1:]))
}
if !paginater.IsLast() {
u := *curURL
queries := u.Query()
queries.Set("page", fmt.Sprintf("%d", paginater.TotalPages()))
u.RawQuery = queries.Encode()
links = append(links, fmt.Sprintf("<%s%s>; rel=\"last\"", setting.AppURL, u.RequestURI()[1:]))
}
if !paginater.IsFirst() {
u := *curURL
queries := u.Query()
queries.Set("page", "1")
u.RawQuery = queries.Encode()
links = append(links, fmt.Sprintf("<%s%s>; rel=\"first\"", setting.AppURL, u.RequestURI()[1:]))
}
if paginater.HasPrevious() {
u := *curURL
queries := u.Query()
queries.Set("page", fmt.Sprintf("%d", paginater.Previous()))
u.RawQuery = queries.Encode()
links = append(links, fmt.Sprintf("<%s%s>; rel=\"prev\"", setting.AppURL, u.RequestURI()[1:]))
}
return links
}
// SetLinkHeader sets pagination link header by given total number and page size.
func (ctx *APIContext) SetLinkHeader(total, pageSize int) {
links := genAPILinks(ctx.Req.URL, total, pageSize, ctx.FormInt("page"))
if len(links) > 0 {
ctx.RespHeader().Set("Link", strings.Join(links, ","))
ctx.AppendAccessControlExposeHeaders("Link")
}
}
// APIContexter returns apicontext as middleware
func APIContexter() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
base, baseCleanUp := NewBaseContext(w, req)
ctx := &APIContext{
Base: base,
Cache: mc.GetCache(),
Repo: &Repository{PullRequest: &PullRequest{}},
Org: &APIOrganization{},
}
defer baseCleanUp()
ctx.Base.AppendContextValue(apiContextKey, ctx)
ctx.Base.AppendContextValueFunc(gitrepo.RepositoryContextKey, func() any { return ctx.Repo.GitRepo })
// If request sends files, parse them here otherwise the Query() can't be parsed and the CsrfToken will be invalid.
if ctx.Req.Method == "POST" && strings.Contains(ctx.Req.Header.Get("Content-Type"), "multipart/form-data") {
if err := ctx.Req.ParseMultipartForm(setting.Attachment.MaxSize << 20); err != nil && !strings.Contains(err.Error(), "EOF") { // 32MB max size
ctx.InternalServerError(err)
return
}
}
httpcache.SetCacheControlInHeader(ctx.Resp.Header(), 0, "no-transform")
ctx.Resp.Header().Set(`X-Frame-Options`, setting.CORSConfig.XFrameOptions)
next.ServeHTTP(ctx.Resp, ctx.Req)
})
}
}
// NotFound handles 404s for APIContext
// String will replace message, errors will be added to a slice
func (ctx *APIContext) NotFound(objs ...any) {
message := ctx.Locale.TrString("error.not_found")
var errors []string
for _, obj := range objs {
// Ignore nil
if obj == nil {
continue
}
if err, ok := obj.(error); ok {
errors = append(errors, err.Error())
} else {
message = obj.(string)
}
}
ctx.JSON(http.StatusNotFound, map[string]any{
"message": message,
"url": setting.API.SwaggerURL,
"errors": errors,
})
}
// ReferencesGitRepo injects the GitRepo into the Context
// you can optional skip the IsEmpty check
func ReferencesGitRepo(allowEmpty ...bool) func(ctx *APIContext) (cancel context.CancelFunc) {
return func(ctx *APIContext) (cancel context.CancelFunc) {
// Empty repository does not have reference information.
if ctx.Repo.Repository.IsEmpty && !(len(allowEmpty) != 0 && allowEmpty[0]) {
return nil
}
// For API calls.
if ctx.Repo.GitRepo == nil {
gitRepo, err := gitrepo.OpenRepository(ctx, ctx.Repo.Repository)
if err != nil {
ctx.Error(http.StatusInternalServerError, fmt.Sprintf("Open Repository %v failed", ctx.Repo.Repository.FullName()), err)
return cancel
}
ctx.Repo.GitRepo = gitRepo
// We opened it, we should close it
return func() {
// If it's been set to nil then assume someone else has closed it.
if ctx.Repo.GitRepo != nil {
_ = ctx.Repo.GitRepo.Close()
}
}
}
return cancel
}
}
// RepoRefForAPI handles repository reference names when the ref name is not explicitly given
func RepoRefForAPI(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := GetAPIContext(req)
if ctx.Repo.GitRepo == nil {
ctx.InternalServerError(fmt.Errorf("no open git repo"))
return
}
if ref := ctx.FormTrim("ref"); len(ref) > 0 {
commit, err := ctx.Repo.GitRepo.GetCommit(ref)
if err != nil {
if git.IsErrNotExist(err) {
ctx.NotFound()
} else {
ctx.Error(http.StatusInternalServerError, "GetCommit", err)
}
return
}
ctx.Repo.Commit = commit
ctx.Repo.CommitID = ctx.Repo.Commit.ID.String()
ctx.Repo.TreePath = ctx.Params("*")
next.ServeHTTP(w, req)
return
}
refName := getRefName(ctx.Base, ctx.Repo, RepoRefAny)
var err error
if ctx.Repo.GitRepo.IsBranchExist(refName) {
ctx.Repo.Commit, err = ctx.Repo.GitRepo.GetBranchCommit(refName)
if err != nil {
ctx.InternalServerError(err)
return
}
ctx.Repo.CommitID = ctx.Repo.Commit.ID.String()
} else if ctx.Repo.GitRepo.IsTagExist(refName) {
ctx.Repo.Commit, err = ctx.Repo.GitRepo.GetTagCommit(refName)
if err != nil {
ctx.InternalServerError(err)
return
}
ctx.Repo.CommitID = ctx.Repo.Commit.ID.String()
} else if len(refName) == ctx.Repo.GetObjectFormat().FullLength() {
ctx.Repo.CommitID = refName
ctx.Repo.Commit, err = ctx.Repo.GitRepo.GetCommit(refName)
if err != nil {
ctx.NotFound("GetCommit", err)
return
}
} else {
ctx.NotFound(fmt.Errorf("not exist: '%s'", ctx.Params("*")))
return
}
next.ServeHTTP(w, req)
})
}
// HasAPIError returns true if error occurs in form validation.
func (ctx *APIContext) HasAPIError() bool {
hasErr, ok := ctx.Data["HasError"]
if !ok {
return false
}
return hasErr.(bool)
}
// GetErrMsg returns error message in form validation.
func (ctx *APIContext) GetErrMsg() string {
msg, _ := ctx.Data["ErrorMsg"].(string)
if msg == "" {
msg = "invalid form data"
}
return msg
}
// NotFoundOrServerError use error check function to determine if the error
// is about not found. It responds with 404 status code for not found error,
// or error context description for logging purpose of 500 server error.
func (ctx *APIContext) NotFoundOrServerError(logMsg string, errCheck func(error) bool, logErr error) {
if errCheck(logErr) {
ctx.JSON(http.StatusNotFound, nil)
return
}
ctx.Error(http.StatusInternalServerError, "NotFoundOrServerError", logMsg)
}
// IsUserSiteAdmin returns true if current user is a site admin
func (ctx *APIContext) IsUserSiteAdmin() bool {
return ctx.IsSigned && ctx.Doer.IsAdmin
}
// IsUserRepoAdmin returns true if current user is admin in current repo
func (ctx *APIContext) IsUserRepoAdmin() bool {
return ctx.Repo.IsAdmin()
}
// IsUserRepoWriter returns true if current user has write privilege in current repo
func (ctx *APIContext) IsUserRepoWriter(unitTypes []unit.Type) bool {
for _, unitType := range unitTypes {
if ctx.Repo.CanWrite(unitType) {
return true
}
}
return false
}
+12
View File
@@ -0,0 +1,12 @@
// Copyright 2016 The Gogs Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import "code.gitea.io/gitea/models/organization"
// APIOrganization contains organization and team
type APIOrganization struct {
Organization *organization.Organization
Team *organization.Team
}
+50
View File
@@ -0,0 +1,50 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"net/url"
"strconv"
"testing"
"code.gitea.io/gitea/modules/setting"
"github.com/stretchr/testify/assert"
)
func TestGenAPILinks(t *testing.T) {
setting.AppURL = "http://localhost:3000/"
kases := map[string][]string{
"api/v1/repos/jerrykan/example-repo/issues?state=all": {
`<http://localhost:3000/api/v1/repos/jerrykan/example-repo/issues?page=2&state=all>; rel="next"`,
`<http://localhost:3000/api/v1/repos/jerrykan/example-repo/issues?page=5&state=all>; rel="last"`,
},
"api/v1/repos/jerrykan/example-repo/issues?state=all&page=1": {
`<http://localhost:3000/api/v1/repos/jerrykan/example-repo/issues?page=2&state=all>; rel="next"`,
`<http://localhost:3000/api/v1/repos/jerrykan/example-repo/issues?page=5&state=all>; rel="last"`,
},
"api/v1/repos/jerrykan/example-repo/issues?state=all&page=2": {
`<http://localhost:3000/api/v1/repos/jerrykan/example-repo/issues?page=3&state=all>; rel="next"`,
`<http://localhost:3000/api/v1/repos/jerrykan/example-repo/issues?page=5&state=all>; rel="last"`,
`<http://localhost:3000/api/v1/repos/jerrykan/example-repo/issues?page=1&state=all>; rel="first"`,
`<http://localhost:3000/api/v1/repos/jerrykan/example-repo/issues?page=1&state=all>; rel="prev"`,
},
"api/v1/repos/jerrykan/example-repo/issues?state=all&page=5": {
`<http://localhost:3000/api/v1/repos/jerrykan/example-repo/issues?page=1&state=all>; rel="first"`,
`<http://localhost:3000/api/v1/repos/jerrykan/example-repo/issues?page=4&state=all>; rel="prev"`,
},
}
for req, response := range kases {
u, err := url.Parse(setting.AppURL + req)
assert.NoError(t, err)
p := u.Query().Get("page")
curPage, _ := strconv.Atoi(p)
links := genAPILinks(u, 100, 20, curPage)
assert.EqualValues(t, links, response)
}
}
+317
View File
@@ -0,0 +1,317 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"context"
"fmt"
"html/template"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"code.gitea.io/gitea/modules/httplib"
"code.gitea.io/gitea/modules/json"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/optional"
"code.gitea.io/gitea/modules/translation"
"code.gitea.io/gitea/modules/web/middleware"
"github.com/go-chi/chi/v5"
)
type contextValuePair struct {
key any
valueFn func() any
}
type Base struct {
originCtx context.Context
contextValues []contextValuePair
Resp ResponseWriter
Req *http.Request
// Data is prepared by ContextDataStore middleware, this field only refers to the pre-created/prepared ContextData.
// Although it's mainly used for MVC templates, sometimes it's also used to pass data between middlewares/handler
Data middleware.ContextData
// Locale is mainly for Web context, although the API context also uses it in some cases: message response, form validation
Locale translation.Locale
}
func (b *Base) Deadline() (deadline time.Time, ok bool) {
return b.originCtx.Deadline()
}
func (b *Base) Done() <-chan struct{} {
return b.originCtx.Done()
}
func (b *Base) Err() error {
return b.originCtx.Err()
}
func (b *Base) Value(key any) any {
for _, pair := range b.contextValues {
if pair.key == key {
return pair.valueFn()
}
}
return b.originCtx.Value(key)
}
func (b *Base) AppendContextValueFunc(key any, valueFn func() any) any {
b.contextValues = append(b.contextValues, contextValuePair{key, valueFn})
return b
}
func (b *Base) AppendContextValue(key, value any) any {
b.contextValues = append(b.contextValues, contextValuePair{key, func() any { return value }})
return b
}
func (b *Base) GetData() middleware.ContextData {
return b.Data
}
// AppendAccessControlExposeHeaders append headers by name to "Access-Control-Expose-Headers" header
func (b *Base) AppendAccessControlExposeHeaders(names ...string) {
val := b.RespHeader().Get("Access-Control-Expose-Headers")
if len(val) != 0 {
b.RespHeader().Set("Access-Control-Expose-Headers", fmt.Sprintf("%s, %s", val, strings.Join(names, ", ")))
} else {
b.RespHeader().Set("Access-Control-Expose-Headers", strings.Join(names, ", "))
}
}
// SetTotalCountHeader set "X-Total-Count" header
func (b *Base) SetTotalCountHeader(total int64) {
b.RespHeader().Set("X-Total-Count", fmt.Sprint(total))
b.AppendAccessControlExposeHeaders("X-Total-Count")
}
// Written returns true if there are something sent to web browser
func (b *Base) Written() bool {
return b.Resp.WrittenStatus() != 0
}
func (b *Base) WrittenStatus() int {
return b.Resp.WrittenStatus()
}
// Status writes status code
func (b *Base) Status(status int) {
b.Resp.WriteHeader(status)
}
// Write writes data to web browser
func (b *Base) Write(bs []byte) (int, error) {
return b.Resp.Write(bs)
}
// RespHeader returns the response header
func (b *Base) RespHeader() http.Header {
return b.Resp.Header()
}
// Error returned an error to web browser
func (b *Base) Error(status int, contents ...string) {
v := http.StatusText(status)
if len(contents) > 0 {
v = contents[0]
}
http.Error(b.Resp, v, status)
}
// JSON render content as JSON
func (b *Base) JSON(status int, content any) {
b.Resp.Header().Set("Content-Type", "application/json;charset=utf-8")
b.Resp.WriteHeader(status)
if err := json.NewEncoder(b.Resp).Encode(content); err != nil {
log.Error("Render JSON failed: %v", err)
}
}
// RemoteAddr returns the client machine ip address
func (b *Base) RemoteAddr() string {
return b.Req.RemoteAddr
}
// Params returns the param on route
func (b *Base) Params(p string) string {
s, _ := url.PathUnescape(chi.URLParam(b.Req, strings.TrimPrefix(p, ":")))
return s
}
func (b *Base) PathParamRaw(p string) string {
return chi.URLParam(b.Req, strings.TrimPrefix(p, ":"))
}
// ParamsInt64 returns the param on route as int64
func (b *Base) ParamsInt64(p string) int64 {
v, _ := strconv.ParseInt(b.Params(p), 10, 64)
return v
}
// SetParams set params into routes
func (b *Base) SetParams(k, v string) {
chiCtx := chi.RouteContext(b)
chiCtx.URLParams.Add(strings.TrimPrefix(k, ":"), url.PathEscape(v))
}
// FormString returns the first value matching the provided key in the form as a string
func (b *Base) FormString(key string) string {
return b.Req.FormValue(key)
}
// FormStrings returns a string slice for the provided key from the form
func (b *Base) FormStrings(key string) []string {
if b.Req.Form == nil {
if err := b.Req.ParseMultipartForm(32 << 20); err != nil {
return nil
}
}
if v, ok := b.Req.Form[key]; ok {
return v
}
return nil
}
// FormTrim returns the first value for the provided key in the form as a space trimmed string
func (b *Base) FormTrim(key string) string {
return strings.TrimSpace(b.Req.FormValue(key))
}
// FormInt returns the first value for the provided key in the form as an int
func (b *Base) FormInt(key string) int {
v, _ := strconv.Atoi(b.Req.FormValue(key))
return v
}
// FormInt64 returns the first value for the provided key in the form as an int64
func (b *Base) FormInt64(key string) int64 {
v, _ := strconv.ParseInt(b.Req.FormValue(key), 10, 64)
return v
}
// FormBool returns true if the value for the provided key in the form is "1", "true" or "on"
func (b *Base) FormBool(key string) bool {
s := b.Req.FormValue(key)
v, _ := strconv.ParseBool(s)
v = v || strings.EqualFold(s, "on")
return v
}
// FormOptionalBool returns an optional.Some(true) or optional.Some(false) if the value
// for the provided key exists in the form else it returns optional.None[bool]()
func (b *Base) FormOptionalBool(key string) optional.Option[bool] {
value := b.Req.FormValue(key)
if len(value) == 0 {
return optional.None[bool]()
}
s := b.Req.FormValue(key)
v, _ := strconv.ParseBool(s)
v = v || strings.EqualFold(s, "on")
return optional.Some(v)
}
func (b *Base) SetFormString(key, value string) {
_ = b.Req.FormValue(key) // force parse form
b.Req.Form.Set(key, value)
}
// PlainTextBytes renders bytes as plain text
func (b *Base) plainTextInternal(skip, status int, bs []byte) {
statusPrefix := status / 100
if statusPrefix == 4 || statusPrefix == 5 {
log.Log(skip, log.TRACE, "plainTextInternal (status=%d): %s", status, string(bs))
}
b.Resp.Header().Set("Content-Type", "text/plain;charset=utf-8")
b.Resp.Header().Set("X-Content-Type-Options", "nosniff")
b.Resp.WriteHeader(status)
if _, err := b.Resp.Write(bs); err != nil {
log.ErrorWithSkip(skip, "plainTextInternal (status=%d): write bytes failed: %v", status, err)
}
}
// PlainTextBytes renders bytes as plain text
func (b *Base) PlainTextBytes(status int, bs []byte) {
b.plainTextInternal(2, status, bs)
}
// PlainText renders content as plain text
func (b *Base) PlainText(status int, text string) {
b.plainTextInternal(2, status, []byte(text))
}
// Redirect redirects the request
func (b *Base) Redirect(location string, status ...int) {
code := http.StatusSeeOther
if len(status) == 1 {
code = status[0]
}
if strings.HasPrefix(location, "http://") || strings.HasPrefix(location, "https://") || strings.HasPrefix(location, "//") {
// Some browsers (Safari) have buggy behavior for Cookie + Cache + External Redirection, eg: /my-path => https://other/path
// 1. the first request to "/my-path" contains cookie
// 2. some time later, the request to "/my-path" doesn't contain cookie (caused by Prevent web tracking)
// 3. Gitea's Sessioner doesn't see the session cookie, so it generates a new session id, and returns it to browser
// 4. then the browser accepts the empty session, then the user is logged out
// So in this case, we should remove the session cookie from the response header
removeSessionCookieHeader(b.Resp)
}
// in case the request is made by htmx, have it redirect the browser instead of trying to follow the redirect inside htmx
if b.Req.Header.Get("HX-Request") == "true" {
b.Resp.Header().Set("HX-Redirect", location)
// we have to return a non-redirect status code so XMLHTTPRequest will not immediately follow the redirect
// so as to give htmx redirect logic a chance to run
b.Status(http.StatusNoContent)
return
}
http.Redirect(b.Resp, b.Req, location, code)
}
type ServeHeaderOptions httplib.ServeHeaderOptions
func (b *Base) SetServeHeaders(opt *ServeHeaderOptions) {
httplib.ServeSetHeaders(b.Resp, (*httplib.ServeHeaderOptions)(opt))
}
// ServeContent serves content to http request
func (b *Base) ServeContent(r io.ReadSeeker, opts *ServeHeaderOptions) {
httplib.ServeSetHeaders(b.Resp, (*httplib.ServeHeaderOptions)(opts))
http.ServeContent(b.Resp, b.Req, opts.Filename, opts.LastModified, r)
}
// Close frees all resources hold by Context
func (b *Base) cleanUp() {
if b.Req != nil && b.Req.MultipartForm != nil {
_ = b.Req.MultipartForm.RemoveAll() // remove the temp files buffered to tmp directory
}
}
func (b *Base) Tr(msg string, args ...any) template.HTML {
return b.Locale.Tr(msg, args...)
}
func (b *Base) TrN(cnt any, key1, keyN string, args ...any) template.HTML {
return b.Locale.TrN(cnt, key1, keyN, args...)
}
func NewBaseContext(resp http.ResponseWriter, req *http.Request) (b *Base, closeFunc func()) {
b = &Base{
originCtx: req.Context(),
Req: req,
Resp: WrapResponseWriter(resp),
Locale: middleware.Locale(resp, req),
Data: middleware.GetContextData(req.Context()),
}
b.AppendContextValue(translation.ContextKey, b.Locale)
b.Req = b.Req.WithContext(b)
return b, b.cleanUp
}
+47
View File
@@ -0,0 +1,47 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"net/http"
"net/http/httptest"
"testing"
"code.gitea.io/gitea/modules/setting"
"github.com/stretchr/testify/assert"
)
func TestRedirect(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
cases := []struct {
url string
keep bool
}{
{"http://test", false},
{"https://test", false},
{"//test", false},
{"/://test", true},
{"/test", true},
}
for _, c := range cases {
resp := httptest.NewRecorder()
b, cleanup := NewBaseContext(resp, req)
resp.Header().Add("Set-Cookie", (&http.Cookie{Name: setting.SessionConfig.CookieName, Value: "dummy"}).String())
b.Redirect(c.url)
cleanup()
has := resp.Header().Get("Set-Cookie") == "i_like_gitea=dummy"
assert.Equal(t, c.keep, has, "url = %q", c.url)
}
req, _ = http.NewRequest("GET", "/", nil)
resp := httptest.NewRecorder()
req.Header.Add("HX-Request", "true")
b, cleanup := NewBaseContext(resp, req)
b.Redirect("/other")
cleanup()
assert.Equal(t, "/other", resp.Header().Get("HX-Redirect"))
assert.Equal(t, http.StatusNoContent, resp.Code)
}
+93
View File
@@ -0,0 +1,93 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"fmt"
"sync"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/cache"
"code.gitea.io/gitea/modules/hcaptcha"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/mcaptcha"
"code.gitea.io/gitea/modules/recaptcha"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/turnstile"
"gitea.com/go-chi/captcha"
)
var (
imageCaptchaOnce sync.Once
cpt *captcha.Captcha
)
// GetImageCaptcha returns global image captcha
func GetImageCaptcha() *captcha.Captcha {
imageCaptchaOnce.Do(func() {
cpt = captcha.NewCaptcha(captcha.Options{
SubURL: setting.AppSubURL,
})
cpt.Store = cache.GetCache()
})
return cpt
}
// SetCaptchaData sets common captcha data
func SetCaptchaData(ctx *Context) {
if !setting.Service.EnableCaptcha {
return
}
ctx.Data["EnableCaptcha"] = setting.Service.EnableCaptcha
ctx.Data["RecaptchaURL"] = setting.Service.RecaptchaURL
ctx.Data["Captcha"] = GetImageCaptcha()
ctx.Data["CaptchaType"] = setting.Service.CaptchaType
ctx.Data["RecaptchaSitekey"] = setting.Service.RecaptchaSitekey
ctx.Data["HcaptchaSitekey"] = setting.Service.HcaptchaSitekey
ctx.Data["McaptchaSitekey"] = setting.Service.McaptchaSitekey
ctx.Data["McaptchaURL"] = setting.Service.McaptchaURL
ctx.Data["CfTurnstileSitekey"] = setting.Service.CfTurnstileSitekey
}
const (
gRecaptchaResponseField = "g-recaptcha-response"
hCaptchaResponseField = "h-captcha-response"
mCaptchaResponseField = "m-captcha-response"
cfTurnstileResponseField = "cf-turnstile-response"
)
// VerifyCaptcha verifies Captcha data
// No-op if captchas are not enabled
func VerifyCaptcha(ctx *Context, tpl base.TplName, form any) {
if !setting.Service.EnableCaptcha {
return
}
var valid bool
var err error
switch setting.Service.CaptchaType {
case setting.ImageCaptcha:
valid = GetImageCaptcha().VerifyReq(ctx.Req)
case setting.ReCaptcha:
valid, err = recaptcha.Verify(ctx, ctx.Req.Form.Get(gRecaptchaResponseField))
case setting.HCaptcha:
valid, err = hcaptcha.Verify(ctx, ctx.Req.Form.Get(hCaptchaResponseField))
case setting.MCaptcha:
valid, err = mcaptcha.Verify(ctx, ctx.Req.Form.Get(mCaptchaResponseField))
case setting.CfTurnstile:
valid, err = turnstile.Verify(ctx, ctx.Req.Form.Get(cfTurnstileResponseField))
default:
ctx.ServerError("Unknown Captcha Type", fmt.Errorf("Unknown Captcha Type: %s", setting.Service.CaptchaType))
return
}
if err != nil {
log.Debug("%v", err)
}
if !valid {
ctx.Data["Err_Captcha"] = true
ctx.RenderWithErr(ctx.Tr("form.captcha_incorrect"), tpl, form)
}
}
+257
View File
@@ -0,0 +1,257 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"context"
"encoding/hex"
"fmt"
"html/template"
"io"
"net/http"
"net/url"
"strings"
"time"
"code.gitea.io/gitea/models/unit"
user_model "code.gitea.io/gitea/models/user"
mc "code.gitea.io/gitea/modules/cache"
"code.gitea.io/gitea/modules/gitrepo"
"code.gitea.io/gitea/modules/httpcache"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/templates"
"code.gitea.io/gitea/modules/translation"
"code.gitea.io/gitea/modules/web"
"code.gitea.io/gitea/modules/web/middleware"
web_types "code.gitea.io/gitea/modules/web/types"
"gitea.com/go-chi/cache"
"gitea.com/go-chi/session"
)
// Render represents a template render
type Render interface {
TemplateLookup(tmpl string, templateCtx context.Context) (templates.TemplateExecutor, error)
HTML(w io.Writer, status int, name string, data any, templateCtx context.Context) error
}
// Context represents context of a request.
type Context struct {
*Base
TemplateContext TemplateContext
Render Render
PageData map[string]any // data used by JavaScript modules in one page, it's `window.config.pageData`
Cache cache.Cache
Csrf CSRFProtector
Flash *middleware.Flash
Session session.Store
Link string // current request URL (without query string)
Doer *user_model.User // current signed-in user
IsSigned bool
IsBasicAuth bool
ContextUser *user_model.User // the user which is being visited, in most cases it differs from Doer
Repo *Repository
Org *Organization
Package *Package
}
type TemplateContext map[string]any
func init() {
web.RegisterResponseStatusProvider[*Context](func(req *http.Request) web_types.ResponseStatusProvider {
return req.Context().Value(WebContextKey).(*Context)
})
}
type webContextKeyType struct{}
var WebContextKey = webContextKeyType{}
func GetWebContext(req *http.Request) *Context {
ctx, _ := req.Context().Value(WebContextKey).(*Context)
return ctx
}
// ValidateContext is a special context for form validation middleware. It may be different from other contexts.
type ValidateContext struct {
*Base
}
// GetValidateContext gets a context for middleware form validation
func GetValidateContext(req *http.Request) (ctx *ValidateContext) {
if ctxAPI, ok := req.Context().Value(apiContextKey).(*APIContext); ok {
ctx = &ValidateContext{Base: ctxAPI.Base}
} else if ctxWeb, ok := req.Context().Value(WebContextKey).(*Context); ok {
ctx = &ValidateContext{Base: ctxWeb.Base}
} else {
panic("invalid context, expect either APIContext or Context")
}
return ctx
}
func NewTemplateContextForWeb(ctx *Context) TemplateContext {
tmplCtx := NewTemplateContext(ctx)
tmplCtx["Locale"] = ctx.Base.Locale
tmplCtx["AvatarUtils"] = templates.NewAvatarUtils(ctx)
return tmplCtx
}
func NewWebContext(base *Base, render Render, session session.Store) *Context {
ctx := &Context{
Base: base,
Render: render,
Session: session,
Cache: mc.GetCache(),
Link: setting.AppSubURL + strings.TrimSuffix(base.Req.URL.EscapedPath(), "/"),
Repo: &Repository{PullRequest: &PullRequest{}},
Org: &Organization{},
}
ctx.TemplateContext = NewTemplateContextForWeb(ctx)
ctx.Flash = &middleware.Flash{DataStore: ctx, Values: url.Values{}}
return ctx
}
// Contexter initializes a classic context for a request.
func Contexter() func(next http.Handler) http.Handler {
rnd := templates.HTMLRenderer()
csrfOpts := CsrfOptions{
Secret: hex.EncodeToString(setting.GetGeneralTokenSigningSecret()),
Cookie: setting.CSRFCookieName,
SetCookie: true,
Secure: setting.SessionConfig.Secure,
CookieHTTPOnly: setting.CSRFCookieHTTPOnly,
Header: "X-Csrf-Token",
CookieDomain: setting.SessionConfig.Domain,
CookiePath: setting.SessionConfig.CookiePath,
SameSite: setting.SessionConfig.SameSite,
}
if !setting.IsProd {
CsrfTokenRegenerationInterval = 5 * time.Second // in dev, re-generate the tokens more aggressively for debug purpose
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
base, baseCleanUp := NewBaseContext(resp, req)
defer baseCleanUp()
ctx := NewWebContext(base, rnd, session.GetSession(req))
ctx.Data.MergeFrom(middleware.CommonTemplateContextData())
ctx.Data["Context"] = ctx // TODO: use "ctx" in template and remove this
ctx.Data["CurrentURL"] = setting.AppSubURL + req.URL.RequestURI()
ctx.Data["Link"] = ctx.Link
// PageData is passed by reference, and it will be rendered to `window.config.pageData` in `head.tmpl` for JavaScript modules
ctx.PageData = map[string]any{}
ctx.Data["PageData"] = ctx.PageData
ctx.Base.AppendContextValue(WebContextKey, ctx)
ctx.Base.AppendContextValueFunc(gitrepo.RepositoryContextKey, func() any { return ctx.Repo.GitRepo })
ctx.Csrf = PrepareCSRFProtector(csrfOpts, ctx)
// Get the last flash message from cookie
lastFlashCookie := middleware.GetSiteCookie(ctx.Req, CookieNameFlash)
if vals, _ := url.ParseQuery(lastFlashCookie); len(vals) > 0 {
// store last Flash message into the template data, to render it
ctx.Data["Flash"] = &middleware.Flash{
DataStore: ctx,
Values: vals,
ErrorMsg: vals.Get("error"),
SuccessMsg: vals.Get("success"),
InfoMsg: vals.Get("info"),
WarningMsg: vals.Get("warning"),
}
}
// if there are new messages in the ctx.Flash, write them into cookie
ctx.Resp.Before(func(resp ResponseWriter) {
if val := ctx.Flash.Encode(); val != "" {
middleware.SetSiteCookie(ctx.Resp, CookieNameFlash, val, 0)
} else if lastFlashCookie != "" {
middleware.SetSiteCookie(ctx.Resp, CookieNameFlash, "", -1)
}
})
// If request sends files, parse them here otherwise the Query() can't be parsed and the CsrfToken will be invalid.
if ctx.Req.Method == "POST" && strings.Contains(ctx.Req.Header.Get("Content-Type"), "multipart/form-data") {
if err := ctx.Req.ParseMultipartForm(setting.Attachment.MaxSize << 20); err != nil && !strings.Contains(err.Error(), "EOF") { // 32MB max size
ctx.ServerError("ParseMultipartForm", err)
return
}
}
httpcache.SetCacheControlInHeader(ctx.Resp.Header(), 0, "no-transform")
ctx.Resp.Header().Set(`X-Frame-Options`, setting.CORSConfig.XFrameOptions)
ctx.Data["SystemConfig"] = setting.Config()
ctx.Data["CsrfToken"] = ctx.Csrf.GetToken()
ctx.Data["CsrfTokenHtml"] = template.HTML(`<input type="hidden" name="_csrf" value="` + ctx.Data["CsrfToken"].(string) + `">`)
// FIXME: do we really always need these setting? There should be someway to have to avoid having to always set these
ctx.Data["DisableMigrations"] = setting.Repository.DisableMigrations
ctx.Data["DisableStars"] = setting.Repository.DisableStars
ctx.Data["EnableActions"] = setting.Actions.Enabled
ctx.Data["ManifestData"] = setting.ManifestData
ctx.Data["UnitWikiGlobalDisabled"] = unit.TypeWiki.UnitGlobalDisabled()
ctx.Data["UnitIssuesGlobalDisabled"] = unit.TypeIssues.UnitGlobalDisabled()
ctx.Data["UnitPullsGlobalDisabled"] = unit.TypePullRequests.UnitGlobalDisabled()
ctx.Data["UnitProjectsGlobalDisabled"] = unit.TypeProjects.UnitGlobalDisabled()
ctx.Data["UnitActionsGlobalDisabled"] = unit.TypeActions.UnitGlobalDisabled()
ctx.Data["AllLangs"] = translation.AllLangs()
next.ServeHTTP(ctx.Resp, ctx.Req)
})
}
}
// HasError returns true if error occurs in form validation.
// Attention: this function changes ctx.Data and ctx.Flash
func (ctx *Context) HasError() bool {
hasErr, ok := ctx.Data["HasError"]
if !ok {
return false
}
ctx.Flash.ErrorMsg = ctx.GetErrMsg()
ctx.Data["Flash"] = ctx.Flash
return hasErr.(bool)
}
// GetErrMsg returns error message in form validation.
func (ctx *Context) GetErrMsg() string {
msg, _ := ctx.Data["ErrorMsg"].(string)
if msg == "" {
msg = "invalid form data"
}
return msg
}
func (ctx *Context) JSONRedirect(redirect string) {
ctx.JSON(http.StatusOK, map[string]any{"redirect": redirect})
}
func (ctx *Context) JSONOK() {
ctx.JSON(http.StatusOK, map[string]any{"ok": true}) // this is only a dummy response, frontend seldom uses it
}
func (ctx *Context) JSONError(msg any) {
switch v := msg.(type) {
case string:
ctx.JSON(http.StatusBadRequest, map[string]any{"errorMessage": v, "renderFormat": "text"})
case template.HTML:
ctx.JSON(http.StatusBadRequest, map[string]any{"errorMessage": v, "renderFormat": "html"})
default:
panic(fmt.Sprintf("unsupported type: %T", msg))
}
}
+42
View File
@@ -0,0 +1,42 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"net/http"
"strings"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/web/middleware"
)
const CookieNameFlash = "gitea_flash"
func removeSessionCookieHeader(w http.ResponseWriter) {
cookies := w.Header()["Set-Cookie"]
w.Header().Del("Set-Cookie")
for _, cookie := range cookies {
if strings.HasPrefix(cookie, setting.SessionConfig.CookieName+"=") {
continue
}
w.Header().Add("Set-Cookie", cookie)
}
}
// SetSiteCookie convenience function to set most cookies consistently
// CSRF and a few others are the exception here
func (ctx *Context) SetSiteCookie(name, value string, maxAge int) {
middleware.SetSiteCookie(ctx.Resp, name, value, maxAge)
}
// DeleteSiteCookie convenience function to delete most cookies consistently
// CSRF and a few others are the exception here
func (ctx *Context) DeleteSiteCookie(name string) {
middleware.SetSiteCookie(ctx.Resp, name, "", -1)
}
// GetSiteCookie returns given cookie value from request header.
func (ctx *Context) GetSiteCookie(name string) string {
return middleware.GetSiteCookie(ctx.Req, name)
}
+29
View File
@@ -0,0 +1,29 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"code.gitea.io/gitea/models/unit"
)
// IsUserSiteAdmin returns true if current user is a site admin
func (ctx *Context) IsUserSiteAdmin() bool {
return ctx.IsSigned && ctx.Doer.IsAdmin
}
// IsUserRepoAdmin returns true if current user is admin in current repo
func (ctx *Context) IsUserRepoAdmin() bool {
return ctx.Repo.IsAdmin()
}
// IsUserRepoWriter returns true if current user has write privilege in current repo
func (ctx *Context) IsUserRepoWriter(unitTypes []unit.Type) bool {
for _, unitType := range unitTypes {
if ctx.Repo.CanWrite(unitType) {
return true
}
}
return false
}
+32
View File
@@ -0,0 +1,32 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"io"
"net/http"
"strings"
)
// UploadStream returns the request body or the first form file
// Only form files need to get closed.
func (ctx *Context) UploadStream() (rd io.ReadCloser, needToClose bool, err error) {
contentType := strings.ToLower(ctx.Req.Header.Get("Content-Type"))
if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") || strings.HasPrefix(contentType, "multipart/form-data") {
if err := ctx.Req.ParseMultipartForm(32 << 20); err != nil {
return nil, false, err
}
if ctx.Req.MultipartForm.File == nil {
return nil, false, http.ErrMissingFile
}
for _, files := range ctx.Req.MultipartForm.File {
if len(files) > 0 {
r, err := files[0].Open()
return r, true, err
}
}
return nil, false, http.ErrMissingFile
}
return ctx.Req.Body, false, nil
}
+190
View File
@@ -0,0 +1,190 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"errors"
"fmt"
"html/template"
"net"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/httplib"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/templates"
"code.gitea.io/gitea/modules/web/middleware"
)
// RedirectToUser redirect to a differently-named user
func RedirectToUser(ctx *Base, userName string, redirectUserID int64) {
user, err := user_model.GetUserByID(ctx, redirectUserID)
if err != nil {
ctx.Error(http.StatusInternalServerError, "unable to get user")
return
}
redirectPath := strings.Replace(
ctx.Req.URL.EscapedPath(),
url.PathEscape(userName),
url.PathEscape(user.Name),
1,
)
if ctx.Req.URL.RawQuery != "" {
redirectPath += "?" + ctx.Req.URL.RawQuery
}
ctx.Redirect(path.Join(setting.AppSubURL, redirectPath), http.StatusTemporaryRedirect)
}
// RedirectToCurrentSite redirects to first not empty URL which belongs to current site
func (ctx *Context) RedirectToCurrentSite(location ...string) {
for _, loc := range location {
if len(loc) == 0 {
continue
}
if !httplib.IsCurrentGiteaSiteURL(loc) {
continue
}
ctx.Redirect(loc)
return
}
ctx.Redirect(setting.AppSubURL + "/")
}
const tplStatus500 base.TplName = "status/500"
// HTML calls Context.HTML and renders the template to HTTP response
func (ctx *Context) HTML(status int, name base.TplName) {
log.Debug("Template: %s", name)
tmplStartTime := time.Now()
if !setting.IsProd {
ctx.Data["TemplateName"] = name
}
ctx.Data["TemplateLoadTimes"] = func() string {
return strconv.FormatInt(time.Since(tmplStartTime).Nanoseconds()/1e6, 10) + "ms"
}
err := ctx.Render.HTML(ctx.Resp, status, string(name), ctx.Data, ctx.TemplateContext)
if err == nil {
return
}
// if rendering fails, show error page
if name != tplStatus500 {
err = fmt.Errorf("failed to render template: %s, error: %s", name, templates.HandleTemplateRenderingError(err))
ctx.ServerError("Render failed", err) // show the 500 error page
} else {
ctx.PlainText(http.StatusInternalServerError, "Unable to render status/500 page, the template system is broken, or Gitea can't find your template files.")
return
}
}
// JSONTemplate renders the template as JSON response
// keep in mind that the template is processed in HTML context, so JSON-things should be handled carefully, eg: by JSEscape
func (ctx *Context) JSONTemplate(tmpl base.TplName) {
t, err := ctx.Render.TemplateLookup(string(tmpl), nil)
if err != nil {
ctx.ServerError("unable to find template", err)
return
}
ctx.Resp.Header().Set("Content-Type", "application/json")
if err = t.Execute(ctx.Resp, ctx.Data); err != nil {
ctx.ServerError("unable to execute template", err)
}
}
// RenderToHTML renders the template content to a HTML string
func (ctx *Context) RenderToHTML(name base.TplName, data map[string]any) (template.HTML, error) {
var buf strings.Builder
err := ctx.Render.HTML(&buf, 0, string(name), data, ctx.TemplateContext)
return template.HTML(buf.String()), err
}
// RenderWithErr used for page has form validation but need to prompt error to users.
func (ctx *Context) RenderWithErr(msg any, tpl base.TplName, form any) {
if form != nil {
middleware.AssignForm(form, ctx.Data)
}
ctx.Flash.Error(msg, true)
ctx.HTML(http.StatusOK, tpl)
}
// NotFound displays a 404 (Not Found) page and prints the given error, if any.
func (ctx *Context) NotFound(logMsg string, logErr error) {
ctx.notFoundInternal(logMsg, logErr)
}
func (ctx *Context) notFoundInternal(logMsg string, logErr error) {
if logErr != nil {
log.Log(2, log.DEBUG, "%s: %v", logMsg, logErr)
if !setting.IsProd {
ctx.Data["ErrorMsg"] = logErr
}
}
// response simple message if Accept isn't text/html
showHTML := false
for _, part := range ctx.Req.Header["Accept"] {
if strings.Contains(part, "text/html") {
showHTML = true
break
}
}
if !showHTML {
ctx.plainTextInternal(3, http.StatusNotFound, []byte("Not found.\n"))
return
}
ctx.Data["IsRepo"] = ctx.Repo.Repository != nil
ctx.Data["Title"] = "Page Not Found"
ctx.HTML(http.StatusNotFound, base.TplName("status/404"))
}
// ServerError displays a 500 (Internal Server Error) page and prints the given error, if any.
func (ctx *Context) ServerError(logMsg string, logErr error) {
ctx.serverErrorInternal(logMsg, logErr)
}
func (ctx *Context) serverErrorInternal(logMsg string, logErr error) {
if logErr != nil {
log.ErrorWithSkip(2, "%s: %v", logMsg, logErr)
if _, ok := logErr.(*net.OpError); ok || errors.Is(logErr, &net.OpError{}) {
// This is an error within the underlying connection
// and further rendering will not work so just return
return
}
// it's safe to show internal error to admin users, and it helps
if !setting.IsProd || (ctx.Doer != nil && ctx.Doer.IsAdmin) {
ctx.Data["ErrorMsg"] = fmt.Sprintf("%s, %s", logMsg, logErr)
}
}
ctx.Data["Title"] = "Internal Server Error"
ctx.HTML(http.StatusInternalServerError, tplStatus500)
}
// NotFoundOrServerError use error check function to determine if the error
// is about not found. It responds with 404 status code for not found error,
// or error context description for logging purpose of 500 server error.
// TODO: remove the "errCheck" and use util.ErrNotFound to check
func (ctx *Context) NotFoundOrServerError(logMsg string, errCheck func(error) bool, logErr error) {
if errCheck(logErr) {
ctx.notFoundInternal(logMsg, logErr)
return
}
ctx.serverErrorInternal(logMsg, logErr)
}
+35
View File
@@ -0,0 +1,35 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"context"
"time"
)
var _ context.Context = TemplateContext(nil)
func NewTemplateContext(ctx context.Context) TemplateContext {
return TemplateContext{"_ctx": ctx}
}
func (c TemplateContext) parentContext() context.Context {
return c["_ctx"].(context.Context)
}
func (c TemplateContext) Deadline() (deadline time.Time, ok bool) {
return c.parentContext().Deadline()
}
func (c TemplateContext) Done() <-chan struct{} {
return c.parentContext().Done()
}
func (c TemplateContext) Err() error {
return c.parentContext().Err()
}
func (c TemplateContext) Value(key any) any {
return c.parentContext().Value(key)
}
+51
View File
@@ -0,0 +1,51 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/test"
"github.com/stretchr/testify/assert"
)
func TestRemoveSessionCookieHeader(t *testing.T) {
w := httptest.NewRecorder()
w.Header().Add("Set-Cookie", (&http.Cookie{Name: setting.SessionConfig.CookieName, Value: "foo"}).String())
w.Header().Add("Set-Cookie", (&http.Cookie{Name: "other", Value: "bar"}).String())
assert.Len(t, w.Header().Values("Set-Cookie"), 2)
removeSessionCookieHeader(w)
assert.Len(t, w.Header().Values("Set-Cookie"), 1)
assert.Contains(t, "other=bar", w.Header().Get("Set-Cookie"))
}
func TestRedirectToCurrentSite(t *testing.T) {
defer test.MockVariableValue(&setting.AppURL, "http://localhost:3000/sub/")()
defer test.MockVariableValue(&setting.AppSubURL, "/sub")()
cases := []struct {
location string
want string
}{
{"/", "/sub/"},
{"http://localhost:3000/sub?k=v", "http://localhost:3000/sub?k=v"},
{"http://other", "/sub/"},
}
for _, c := range cases {
t.Run(c.location, func(t *testing.T) {
req := &http.Request{URL: &url.URL{Path: "/"}}
resp := httptest.NewRecorder()
base, baseCleanUp := NewBaseContext(resp, req)
defer baseCleanUp()
ctx := NewWebContext(base, nil, nil)
ctx.RedirectToCurrentSite(c.location)
redirect := test.RedirectURL(resp)
assert.Equal(t, c.want, redirect)
})
}
}
+242
View File
@@ -0,0 +1,242 @@
// Copyright 2013 Martini Authors
// Copyright 2014 The Macaron Authors
// Copyright 2021 The Gitea Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// SPDX-License-Identifier: Apache-2.0
// a middleware that generates and validates CSRF tokens.
package context
import (
"encoding/base32"
"fmt"
"net/http"
"strconv"
"time"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
"code.gitea.io/gitea/modules/web/middleware"
)
// CSRFProtector represents a CSRF protector and is used to get the current token and validate the token.
type CSRFProtector interface {
// GetHeaderName returns HTTP header to search for token.
GetHeaderName() string
// GetFormName returns form value to search for token.
GetFormName() string
// GetToken returns the token.
GetToken() string
// Validate validates the token in http context.
Validate(ctx *Context)
// DeleteCookie deletes the cookie
DeleteCookie(ctx *Context)
}
type csrfProtector struct {
opt CsrfOptions
// Token generated to pass via header, cookie, or hidden form value.
Token string
// This value must be unique per user.
ID string
}
// GetHeaderName returns the name of the HTTP header for csrf token.
func (c *csrfProtector) GetHeaderName() string {
return c.opt.Header
}
// GetFormName returns the name of the form value for csrf token.
func (c *csrfProtector) GetFormName() string {
return c.opt.Form
}
// GetToken returns the current token. This is typically used
// to populate a hidden form in an HTML template.
func (c *csrfProtector) GetToken() string {
return c.Token
}
// CsrfOptions maintains options to manage behavior of Generate.
type CsrfOptions struct {
// The global secret value used to generate Tokens.
Secret string
// HTTP header used to set and get token.
Header string
// Form value used to set and get token.
Form string
// Cookie value used to set and get token.
Cookie string
// Cookie domain.
CookieDomain string
// Cookie path.
CookiePath string
CookieHTTPOnly bool
// SameSite set the cookie SameSite type
SameSite http.SameSite
// Key used for getting the unique ID per user.
SessionKey string
// oldSessionKey saves old value corresponding to SessionKey.
oldSessionKey string
// If true, send token via X-Csrf-Token header.
SetHeader bool
// If true, send token via _csrf cookie.
SetCookie bool
// Set the Secure flag to true on the cookie.
Secure bool
// Disallow Origin appear in request header.
Origin bool
// Cookie lifetime. Default is 0
CookieLifeTime int
}
func prepareDefaultCsrfOptions(opt CsrfOptions) CsrfOptions {
if opt.Secret == "" {
randBytes, err := util.CryptoRandomBytes(8)
if err != nil {
// this panic can be handled by the recover() in http handlers
panic(fmt.Errorf("failed to generate random bytes: %w", err))
}
opt.Secret = base32.StdEncoding.EncodeToString(randBytes)
}
if opt.Header == "" {
opt.Header = "X-Csrf-Token"
}
if opt.Form == "" {
opt.Form = "_csrf"
}
if opt.Cookie == "" {
opt.Cookie = "_csrf"
}
if opt.CookiePath == "" {
opt.CookiePath = "/"
}
if opt.SessionKey == "" {
opt.SessionKey = "uid"
}
if opt.CookieLifeTime == 0 {
opt.CookieLifeTime = int(CsrfTokenTimeout.Seconds())
}
opt.oldSessionKey = "_old_" + opt.SessionKey
return opt
}
func newCsrfCookie(c *csrfProtector, value string) *http.Cookie {
return &http.Cookie{
Name: c.opt.Cookie,
Value: value,
Path: c.opt.CookiePath,
Domain: c.opt.CookieDomain,
MaxAge: c.opt.CookieLifeTime,
Secure: c.opt.Secure,
HttpOnly: c.opt.CookieHTTPOnly,
SameSite: c.opt.SameSite,
}
}
// PrepareCSRFProtector returns a CSRFProtector to be used for every request.
// Additionally, depending on options set, generated tokens will be sent via Header and/or Cookie.
func PrepareCSRFProtector(opt CsrfOptions, ctx *Context) CSRFProtector {
opt = prepareDefaultCsrfOptions(opt)
x := &csrfProtector{opt: opt}
if opt.Origin && len(ctx.Req.Header.Get("Origin")) > 0 {
return x
}
x.ID = "0"
uidAny := ctx.Session.Get(opt.SessionKey)
if uidAny != nil {
switch uidVal := uidAny.(type) {
case string:
x.ID = uidVal
case int64:
x.ID = strconv.FormatInt(uidVal, 10)
default:
log.Error("invalid uid type in session: %T", uidAny)
}
}
oldUID := ctx.Session.Get(opt.oldSessionKey)
uidChanged := oldUID == nil || oldUID.(string) != x.ID
cookieToken := ctx.GetSiteCookie(opt.Cookie)
needsNew := true
if uidChanged {
_ = ctx.Session.Set(opt.oldSessionKey, x.ID)
} else if cookieToken != "" {
// If cookie token presents, re-use existing unexpired token, else generate a new one.
if issueTime, ok := ParseCsrfToken(cookieToken); ok {
dur := time.Since(issueTime) // issueTime is not a monotonic-clock, the server time may change a lot to an early time.
if dur >= -CsrfTokenRegenerationInterval && dur <= CsrfTokenRegenerationInterval {
x.Token = cookieToken
needsNew = false
}
}
}
if needsNew {
// FIXME: actionId.
x.Token = GenerateCsrfToken(x.opt.Secret, x.ID, "POST", time.Now())
if opt.SetCookie {
cookie := newCsrfCookie(x, x.Token)
ctx.Resp.Header().Add("Set-Cookie", cookie.String())
}
}
if opt.SetHeader {
ctx.Resp.Header().Add(opt.Header, x.Token)
}
return x
}
func (c *csrfProtector) validateToken(ctx *Context, token string) {
if !ValidCsrfToken(token, c.opt.Secret, c.ID, "POST", time.Now()) {
c.DeleteCookie(ctx)
if middleware.IsAPIPath(ctx.Req) {
// currently, there should be no access to the APIPath with CSRF token. because templates shouldn't use the `/api/` endpoints.
http.Error(ctx.Resp, "Invalid CSRF token.", http.StatusBadRequest)
} else {
ctx.Flash.Error(ctx.Tr("error.invalid_csrf"))
ctx.Redirect(setting.AppSubURL + "/")
}
}
}
// Validate should be used as a per route middleware. It attempts to get a token from an "X-Csrf-Token"
// HTTP header and then a "_csrf" form value. If one of these is found, the token will be validated.
// If this validation fails, custom Error is sent in the reply.
// If neither a header nor form value is found, http.StatusBadRequest is sent.
func (c *csrfProtector) Validate(ctx *Context) {
if token := ctx.Req.Header.Get(c.GetHeaderName()); token != "" {
c.validateToken(ctx, token)
return
}
if token := ctx.Req.FormValue(c.GetFormName()); token != "" {
c.validateToken(ctx, token)
return
}
c.validateToken(ctx, "") // no csrf token, use an empty token to respond error
}
func (c *csrfProtector) DeleteCookie(ctx *Context) {
if c.opt.SetCookie {
cookie := newCsrfCookie(c, "")
cookie.MaxAge = -1
ctx.Resp.Header().Add("Set-Cookie", cookie.String())
}
}
+280
View File
@@ -0,0 +1,280 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2020 The Gitea Authors.
// SPDX-License-Identifier: MIT
package context
import (
"strings"
"code.gitea.io/gitea/models/organization"
"code.gitea.io/gitea/models/perm"
"code.gitea.io/gitea/models/unit"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/markup"
"code.gitea.io/gitea/modules/markup/markdown"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/structs"
)
// Organization contains organization context
type Organization struct {
IsOwner bool
IsMember bool
IsTeamMember bool // Is member of team.
IsTeamAdmin bool // In owner team or team that has admin permission level.
Organization *organization.Organization
OrgLink string
CanCreateOrgRepo bool
PublicMemberOnly bool // Only display public members
Team *organization.Team
Teams []*organization.Team
}
func (org *Organization) CanWriteUnit(ctx *Context, unitType unit.Type) bool {
return org.Organization.UnitPermission(ctx, ctx.Doer, unitType) >= perm.AccessModeWrite
}
func (org *Organization) CanReadUnit(ctx *Context, unitType unit.Type) bool {
return org.Organization.UnitPermission(ctx, ctx.Doer, unitType) >= perm.AccessModeRead
}
func GetOrganizationByParams(ctx *Context) {
orgName := ctx.Params(":org")
var err error
ctx.Org.Organization, err = organization.GetOrgByName(ctx, orgName)
if err != nil {
if organization.IsErrOrgNotExist(err) {
redirectUserID, err := user_model.LookupUserRedirect(ctx, orgName)
if err == nil {
RedirectToUser(ctx.Base, orgName, redirectUserID)
} else if user_model.IsErrUserRedirectNotExist(err) {
ctx.NotFound("GetUserByName", err)
} else {
ctx.ServerError("LookupUserRedirect", err)
}
} else {
ctx.ServerError("GetUserByName", err)
}
return
}
}
// HandleOrgAssignment handles organization assignment
func HandleOrgAssignment(ctx *Context, args ...bool) {
var (
requireMember bool
requireOwner bool
requireTeamMember bool
requireTeamAdmin bool
)
if len(args) >= 1 {
requireMember = args[0]
}
if len(args) >= 2 {
requireOwner = args[1]
}
if len(args) >= 3 {
requireTeamMember = args[2]
}
if len(args) >= 4 {
requireTeamAdmin = args[3]
}
var err error
if ctx.ContextUser == nil {
// if Organization is not defined, get it from params
if ctx.Org.Organization == nil {
GetOrganizationByParams(ctx)
if ctx.Written() {
return
}
}
} else if ctx.ContextUser.IsOrganization() {
if ctx.Org == nil {
ctx.Org = &Organization{}
}
ctx.Org.Organization = (*organization.Organization)(ctx.ContextUser)
} else {
// ContextUser is an individual User
return
}
org := ctx.Org.Organization
// Handle Visibility
if org.Visibility != structs.VisibleTypePublic && !ctx.IsSigned {
// We must be signed in to see limited or private organizations
ctx.NotFound("OrgAssignment", err)
return
}
if org.Visibility == structs.VisibleTypePrivate {
requireMember = true
} else if ctx.IsSigned && ctx.Doer.IsRestricted {
requireMember = true
}
ctx.ContextUser = org.AsUser()
ctx.Data["Org"] = org
// Admin has super access.
if ctx.IsSigned && ctx.Doer.IsAdmin {
ctx.Org.IsOwner = true
ctx.Org.IsMember = true
ctx.Org.IsTeamMember = true
ctx.Org.IsTeamAdmin = true
ctx.Org.CanCreateOrgRepo = true
} else if ctx.IsSigned {
ctx.Org.IsOwner, err = org.IsOwnedBy(ctx, ctx.Doer.ID)
if err != nil {
ctx.ServerError("IsOwnedBy", err)
return
}
if ctx.Org.IsOwner {
ctx.Org.IsMember = true
ctx.Org.IsTeamMember = true
ctx.Org.IsTeamAdmin = true
ctx.Org.CanCreateOrgRepo = true
} else {
ctx.Org.IsMember, err = org.IsOrgMember(ctx, ctx.Doer.ID)
if err != nil {
ctx.ServerError("IsOrgMember", err)
return
}
ctx.Org.CanCreateOrgRepo, err = org.CanCreateOrgRepo(ctx, ctx.Doer.ID)
if err != nil {
ctx.ServerError("CanCreateOrgRepo", err)
return
}
}
} else {
// Fake data.
ctx.Data["SignedUser"] = &user_model.User{}
}
if (requireMember && !ctx.Org.IsMember) ||
(requireOwner && !ctx.Org.IsOwner) {
ctx.NotFound("OrgAssignment", err)
return
}
ctx.Data["IsOrganizationOwner"] = ctx.Org.IsOwner
ctx.Data["IsOrganizationMember"] = ctx.Org.IsMember
ctx.Data["IsPackageEnabled"] = setting.Packages.Enabled
ctx.Data["IsRepoIndexerEnabled"] = setting.Indexer.RepoIndexerEnabled
ctx.Data["IsPublicMember"] = func(uid int64) bool {
is, _ := organization.IsPublicMembership(ctx, ctx.Org.Organization.ID, uid)
return is
}
ctx.Data["CanCreateOrgRepo"] = ctx.Org.CanCreateOrgRepo
ctx.Org.OrgLink = org.AsUser().OrganisationLink()
ctx.Data["OrgLink"] = ctx.Org.OrgLink
// Member
ctx.Org.PublicMemberOnly = ctx.Doer == nil || !ctx.Org.IsMember && !ctx.Doer.IsAdmin
opts := &organization.FindOrgMembersOpts{
OrgID: org.ID,
PublicOnly: ctx.Org.PublicMemberOnly,
}
ctx.Data["NumMembers"], err = organization.CountOrgMembers(ctx, opts)
if err != nil {
ctx.ServerError("CountOrgMembers", err)
return
}
// Team.
if ctx.Org.IsMember {
shouldSeeAllTeams := false
if ctx.Org.IsOwner {
shouldSeeAllTeams = true
} else {
teams, err := org.GetUserTeams(ctx, ctx.Doer.ID)
if err != nil {
ctx.ServerError("GetUserTeams", err)
return
}
for _, team := range teams {
if team.IncludesAllRepositories && team.AccessMode >= perm.AccessModeAdmin {
shouldSeeAllTeams = true
break
}
}
}
if shouldSeeAllTeams {
ctx.Org.Teams, err = org.LoadTeams(ctx)
if err != nil {
ctx.ServerError("LoadTeams", err)
return
}
} else {
ctx.Org.Teams, err = org.GetUserTeams(ctx, ctx.Doer.ID)
if err != nil {
ctx.ServerError("GetUserTeams", err)
return
}
}
ctx.Data["NumTeams"] = len(ctx.Org.Teams)
}
teamName := ctx.Params(":team")
if len(teamName) > 0 {
teamExists := false
for _, team := range ctx.Org.Teams {
if team.LowerName == strings.ToLower(teamName) {
teamExists = true
ctx.Org.Team = team
ctx.Org.IsTeamMember = true
ctx.Data["Team"] = ctx.Org.Team
break
}
}
if !teamExists {
ctx.NotFound("OrgAssignment", err)
return
}
ctx.Data["IsTeamMember"] = ctx.Org.IsTeamMember
if requireTeamMember && !ctx.Org.IsTeamMember {
ctx.NotFound("OrgAssignment", err)
return
}
ctx.Org.IsTeamAdmin = ctx.Org.Team.IsOwnerTeam() || ctx.Org.Team.AccessMode >= perm.AccessModeAdmin
ctx.Data["IsTeamAdmin"] = ctx.Org.IsTeamAdmin
if requireTeamAdmin && !ctx.Org.IsTeamAdmin {
ctx.NotFound("OrgAssignment", err)
return
}
}
ctx.Data["ContextUser"] = ctx.ContextUser
ctx.Data["CanReadProjects"] = ctx.Org.CanReadUnit(ctx, unit.TypeProjects)
ctx.Data["CanReadPackages"] = ctx.Org.CanReadUnit(ctx, unit.TypePackages)
ctx.Data["CanReadCode"] = ctx.Org.CanReadUnit(ctx, unit.TypeCode)
ctx.Data["IsFollowing"] = ctx.Doer != nil && user_model.IsFollowing(ctx, ctx.Doer.ID, ctx.ContextUser.ID)
if len(ctx.ContextUser.Description) != 0 {
content, err := markdown.RenderString(&markup.RenderContext{
Metas: map[string]string{"mode": "document"},
Ctx: ctx,
}, ctx.ContextUser.Description)
if err != nil {
ctx.ServerError("RenderString", err)
return
}
ctx.Data["RenderedDescription"] = content
}
}
// OrgAssignment returns a middleware to handle organization assignment
func OrgAssignment(args ...bool) func(ctx *Context) {
return func(ctx *Context) {
HandleOrgAssignment(ctx, args...)
}
}
+165
View File
@@ -0,0 +1,165 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"fmt"
"net/http"
"code.gitea.io/gitea/models/organization"
packages_model "code.gitea.io/gitea/models/packages"
"code.gitea.io/gitea/models/perm"
"code.gitea.io/gitea/models/unit"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/structs"
"code.gitea.io/gitea/modules/templates"
)
// Package contains owner, access mode and optional the package descriptor
type Package struct {
Owner *user_model.User
AccessMode perm.AccessMode
Descriptor *packages_model.PackageDescriptor
}
type packageAssignmentCtx struct {
*Base
Doer *user_model.User
ContextUser *user_model.User
}
// PackageAssignment returns a middleware to handle Context.Package assignment
func PackageAssignment() func(ctx *Context) {
return func(ctx *Context) {
errorFn := func(status int, title string, obj any) {
err, ok := obj.(error)
if !ok {
err = fmt.Errorf("%s", obj)
}
if status == http.StatusNotFound {
ctx.NotFound(title, err)
} else {
ctx.ServerError(title, err)
}
}
paCtx := &packageAssignmentCtx{Base: ctx.Base, Doer: ctx.Doer, ContextUser: ctx.ContextUser}
ctx.Package = packageAssignment(paCtx, errorFn)
}
}
// PackageAssignmentAPI returns a middleware to handle Context.Package assignment
func PackageAssignmentAPI() func(ctx *APIContext) {
return func(ctx *APIContext) {
paCtx := &packageAssignmentCtx{Base: ctx.Base, Doer: ctx.Doer, ContextUser: ctx.ContextUser}
ctx.Package = packageAssignment(paCtx, ctx.Error)
}
}
func packageAssignment(ctx *packageAssignmentCtx, errCb func(int, string, any)) *Package {
pkg := &Package{
Owner: ctx.ContextUser,
}
var err error
pkg.AccessMode, err = determineAccessMode(ctx.Base, pkg, ctx.Doer)
if err != nil {
errCb(http.StatusInternalServerError, "determineAccessMode", err)
return pkg
}
packageType := ctx.Params("type")
name := ctx.Params("name")
version := ctx.Params("version")
if packageType != "" && name != "" && version != "" {
pv, err := packages_model.GetVersionByNameAndVersion(ctx, pkg.Owner.ID, packages_model.Type(packageType), name, version)
if err != nil {
if err == packages_model.ErrPackageNotExist {
errCb(http.StatusNotFound, "GetVersionByNameAndVersion", err)
} else {
errCb(http.StatusInternalServerError, "GetVersionByNameAndVersion", err)
}
return pkg
}
pkg.Descriptor, err = packages_model.GetPackageDescriptor(ctx, pv)
if err != nil {
errCb(http.StatusInternalServerError, "GetPackageDescriptor", err)
return pkg
}
}
return pkg
}
func determineAccessMode(ctx *Base, pkg *Package, doer *user_model.User) (perm.AccessMode, error) {
if setting.Service.RequireSignInView && (doer == nil || doer.IsGhost()) {
return perm.AccessModeNone, nil
}
if doer != nil && !doer.IsGhost() && (!doer.IsActive || doer.ProhibitLogin) {
return perm.AccessModeNone, nil
}
// TODO: ActionUser permission check
accessMode := perm.AccessModeNone
if pkg.Owner.IsOrganization() {
org := organization.OrgFromUser(pkg.Owner)
if doer != nil && !doer.IsGhost() {
// 1. If user is logged in, check all team packages permissions
var err error
accessMode, err = org.GetOrgUserMaxAuthorizeLevel(ctx, doer.ID)
if err != nil {
return accessMode, err
}
// If access mode is less than write check every team for more permissions
// The minimum possible access mode is read for org members
if accessMode < perm.AccessModeWrite {
teams, err := organization.GetUserOrgTeams(ctx, org.ID, doer.ID)
if err != nil {
return accessMode, err
}
for _, t := range teams {
perm := t.UnitAccessMode(ctx, unit.TypePackages)
if accessMode < perm {
accessMode = perm
}
}
}
}
if accessMode == perm.AccessModeNone && organization.HasOrgOrUserVisible(ctx, pkg.Owner, doer) {
// 2. If user is unauthorized or no org member, check if org is visible
accessMode = perm.AccessModeRead
}
} else {
if doer != nil && !doer.IsGhost() {
// 1. Check if user is package owner
if doer.ID == pkg.Owner.ID {
accessMode = perm.AccessModeOwner
} else if pkg.Owner.Visibility == structs.VisibleTypePublic || pkg.Owner.Visibility == structs.VisibleTypeLimited { // 2. Check if package owner is public or limited
accessMode = perm.AccessModeRead
}
} else if pkg.Owner.Visibility == structs.VisibleTypePublic { // 3. Check if package owner is public
accessMode = perm.AccessModeRead
}
}
return accessMode, nil
}
// PackageContexter initializes a package context for a request.
func PackageContexter() func(next http.Handler) http.Handler {
renderer := templates.HTMLRenderer()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
base, baseCleanUp := NewBaseContext(resp, req)
defer baseCleanUp()
// it is still needed when rendering 500 page in a package handler
ctx := NewWebContext(base, renderer, nil)
ctx.Base.AppendContextValue(WebContextKey, ctx)
next.ServeHTTP(ctx.Resp, ctx.Req)
})
}
}
+52
View File
@@ -0,0 +1,52 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"fmt"
"html/template"
"net/url"
"strings"
"code.gitea.io/gitea/modules/paginator"
)
// Pagination provides a pagination via paginator.Paginator and additional configurations for the link params used in rendering
type Pagination struct {
Paginater *paginator.Paginator
urlParams []string
}
// NewPagination creates a new instance of the Pagination struct.
// "pagingNum" is "page size" or "limit", "current" is "page"
func NewPagination(total, pagingNum, current, numPages int) *Pagination {
p := &Pagination{}
p.Paginater = paginator.New(total, pagingNum, current, numPages)
return p
}
// AddParamString adds a string parameter directly
func (p *Pagination) AddParamString(key, value string) {
urlParam := fmt.Sprintf("%s=%v", url.QueryEscape(key), url.QueryEscape(value))
p.urlParams = append(p.urlParams, urlParam)
}
// GetParams returns the configured URL params
func (p *Pagination) GetParams() template.URL {
return template.URL(strings.Join(p.urlParams, "&"))
}
// SetDefaultParams sets common pagination params that are often used
func (p *Pagination) SetDefaultParams(ctx *Context) {
if v, ok := ctx.Data["SortType"].(string); ok {
p.AddParamString("sort", v)
}
if v, ok := ctx.Data["Keyword"].(string); ok {
p.AddParamString("q", v)
}
if v, ok := ctx.Data["IsFuzzy"].(bool); ok {
p.AddParamString("fuzzy", fmt.Sprint(v))
}
// do not add any more uncommon params here!
}
+149
View File
@@ -0,0 +1,149 @@
// Copyright 2018 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"net/http"
auth_model "code.gitea.io/gitea/models/auth"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unit"
"code.gitea.io/gitea/modules/log"
)
// RequireRepoAdmin returns a middleware for requiring repository admin permission
func RequireRepoAdmin() func(ctx *Context) {
return func(ctx *Context) {
if !ctx.IsSigned || !ctx.Repo.IsAdmin() {
ctx.NotFound(ctx.Req.URL.RequestURI(), nil)
return
}
}
}
// RequireRepoWriter returns a middleware for requiring repository write to the specify unitType
func RequireRepoWriter(unitType unit.Type) func(ctx *Context) {
return func(ctx *Context) {
if !ctx.Repo.CanWrite(unitType) {
ctx.NotFound(ctx.Req.URL.RequestURI(), nil)
return
}
}
}
// CanEnableEditor checks if the user is allowed to write to the branch of the repo
func CanEnableEditor() func(ctx *Context) {
return func(ctx *Context) {
if !ctx.Repo.CanWriteToBranch(ctx, ctx.Doer, ctx.Repo.BranchName) {
ctx.NotFound("CanWriteToBranch denies permission", nil)
return
}
}
}
// RequireRepoWriterOr returns a middleware for requiring repository write to one of the unit permission
func RequireRepoWriterOr(unitTypes ...unit.Type) func(ctx *Context) {
return func(ctx *Context) {
for _, unitType := range unitTypes {
if ctx.Repo.CanWrite(unitType) {
return
}
}
ctx.NotFound(ctx.Req.URL.RequestURI(), nil)
}
}
// RequireRepoReader returns a middleware for requiring repository read to the specify unitType
func RequireRepoReader(unitType unit.Type) func(ctx *Context) {
return func(ctx *Context) {
if !ctx.Repo.CanRead(unitType) {
if log.IsTrace() {
if ctx.IsSigned {
log.Trace("Permission Denied: User %-v cannot read %-v in Repo %-v\n"+
"User in Repo has Permissions: %-+v",
ctx.Doer,
unitType,
ctx.Repo.Repository,
ctx.Repo.Permission)
} else {
log.Trace("Permission Denied: Anonymous user cannot read %-v in Repo %-v\n"+
"Anonymous user in Repo has Permissions: %-+v",
unitType,
ctx.Repo.Repository,
ctx.Repo.Permission)
}
}
ctx.NotFound(ctx.Req.URL.RequestURI(), nil)
return
}
}
}
// RequireRepoReaderOr returns a middleware for requiring repository write to one of the unit permission
func RequireRepoReaderOr(unitTypes ...unit.Type) func(ctx *Context) {
return func(ctx *Context) {
for _, unitType := range unitTypes {
if ctx.Repo.CanRead(unitType) {
return
}
}
if log.IsTrace() {
var format string
var args []any
if ctx.IsSigned {
format = "Permission Denied: User %-v cannot read ["
args = append(args, ctx.Doer)
} else {
format = "Permission Denied: Anonymous user cannot read ["
}
for _, unit := range unitTypes {
format += "%-v, "
args = append(args, unit)
}
format = format[:len(format)-2] + "] in Repo %-v\n" +
"User in Repo has Permissions: %-+v"
args = append(args, ctx.Repo.Repository, ctx.Repo.Permission)
log.Trace(format, args...)
}
ctx.NotFound(ctx.Req.URL.RequestURI(), nil)
}
}
// CheckRepoScopedToken check whether personal access token has repo scope
func CheckRepoScopedToken(ctx *Context, repo *repo_model.Repository, level auth_model.AccessTokenScopeLevel) {
if !ctx.IsBasicAuth || ctx.Data["IsApiToken"] != true {
return
}
scope, ok := ctx.Data["ApiTokenScope"].(auth_model.AccessTokenScope)
if ok { // it's a personal access token but not oauth2 token
var scopeMatched bool
requiredScopes := auth_model.GetRequiredScopes(level, auth_model.AccessTokenScopeCategoryRepository)
// check if scope only applies to public resources
publicOnly, err := scope.PublicOnly()
if err != nil {
ctx.ServerError("HasScope", err)
return
}
if publicOnly && repo.IsPrivate {
ctx.Error(http.StatusForbidden)
return
}
scopeMatched, err = scope.HasScope(requiredScopes...)
if err != nil {
ctx.ServerError("HasScope", err)
return
}
if !scopeMatched {
ctx.Error(http.StatusForbidden)
return
}
}
}
+85
View File
@@ -0,0 +1,85 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"context"
"fmt"
"net/http"
"time"
"code.gitea.io/gitea/modules/graceful"
"code.gitea.io/gitea/modules/process"
"code.gitea.io/gitea/modules/web"
web_types "code.gitea.io/gitea/modules/web/types"
)
// PrivateContext represents a context for private routes
type PrivateContext struct {
*Base
Override context.Context
Repo *Repository
}
func init() {
web.RegisterResponseStatusProvider[*PrivateContext](func(req *http.Request) web_types.ResponseStatusProvider {
return req.Context().Value(privateContextKey).(*PrivateContext)
})
}
// Deadline is part of the interface for context.Context and we pass this to the request context
func (ctx *PrivateContext) Deadline() (deadline time.Time, ok bool) {
if ctx.Override != nil {
return ctx.Override.Deadline()
}
return ctx.Base.Deadline()
}
// Done is part of the interface for context.Context and we pass this to the request context
func (ctx *PrivateContext) Done() <-chan struct{} {
if ctx.Override != nil {
return ctx.Override.Done()
}
return ctx.Base.Done()
}
// Err is part of the interface for context.Context and we pass this to the request context
func (ctx *PrivateContext) Err() error {
if ctx.Override != nil {
return ctx.Override.Err()
}
return ctx.Base.Err()
}
var privateContextKey any = "default_private_context"
// GetPrivateContext returns a context for Private routes
func GetPrivateContext(req *http.Request) *PrivateContext {
return req.Context().Value(privateContextKey).(*PrivateContext)
}
// PrivateContexter returns apicontext as middleware
func PrivateContexter() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
base, baseCleanUp := NewBaseContext(w, req)
ctx := &PrivateContext{Base: base}
defer baseCleanUp()
ctx.Base.AppendContextValue(privateContextKey, ctx)
next.ServeHTTP(ctx.Resp, ctx.Req)
})
}
}
// OverrideContext overrides the underlying request context for Done() etc.
// This function should be used when there is a need for work to continue even if the request has been cancelled.
// Primarily this affects hook/post-receive and hook/proc-receive both of which need to continue working even if
// the underlying request has timed out from the ssh/http push
func OverrideContext(ctx *PrivateContext) (cancel context.CancelFunc) {
// We now need to override the request context as the base for our work because even if the request is cancelled we have to continue this work
ctx.Override, _, cancel = process.GetManager().AddTypedContext(graceful.GetManager().HammerContext(), fmt.Sprintf("PrivateContext: %s", ctx.Req.RequestURI), process.RequestProcessType, true)
return cancel
}
File diff suppressed because it is too large Load Diff
+103
View File
@@ -0,0 +1,103 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"net/http"
web_types "code.gitea.io/gitea/modules/web/types"
)
// ResponseWriter represents a response writer for HTTP
type ResponseWriter interface {
http.ResponseWriter
http.Flusher
web_types.ResponseStatusProvider
Before(func(ResponseWriter))
Status() int // used by access logger template
Size() int // used by access logger template
}
var _ ResponseWriter = &Response{}
// Response represents a response
type Response struct {
http.ResponseWriter
written int
status int
befores []func(ResponseWriter)
beforeExecuted bool
}
// Write writes bytes to HTTP endpoint
func (r *Response) Write(bs []byte) (int, error) {
if !r.beforeExecuted {
for _, before := range r.befores {
before(r)
}
r.beforeExecuted = true
}
size, err := r.ResponseWriter.Write(bs)
r.written += size
if err != nil {
return size, err
}
if r.status == 0 {
r.status = http.StatusOK
}
return size, nil
}
func (r *Response) Status() int {
return r.status
}
func (r *Response) Size() int {
return r.written
}
// WriteHeader write status code
func (r *Response) WriteHeader(statusCode int) {
if !r.beforeExecuted {
for _, before := range r.befores {
before(r)
}
r.beforeExecuted = true
}
if r.status == 0 {
r.status = statusCode
r.ResponseWriter.WriteHeader(statusCode)
}
}
// Flush flushes cached data
func (r *Response) Flush() {
if f, ok := r.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// WrittenStatus returned status code written
func (r *Response) WrittenStatus() int {
return r.status
}
// Before allows for a function to be called before the ResponseWriter has been written to. This is
// useful for setting headers or any other operations that must happen before a response has been written.
func (r *Response) Before(f func(ResponseWriter)) {
r.befores = append(r.befores, f)
}
func WrapResponseWriter(resp http.ResponseWriter) *Response {
if v, ok := resp.(*Response); ok {
return v
}
return &Response{
ResponseWriter: resp,
status: 0,
befores: make([]func(ResponseWriter), 0),
}
}
+105
View File
@@ -0,0 +1,105 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package upload
import (
"mime"
"net/http"
"net/url"
"path"
"regexp"
"strings"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/services/context"
)
// ErrFileTypeForbidden not allowed file type error
type ErrFileTypeForbidden struct {
Type string
}
// IsErrFileTypeForbidden checks if an error is a ErrFileTypeForbidden.
func IsErrFileTypeForbidden(err error) bool {
_, ok := err.(ErrFileTypeForbidden)
return ok
}
func (err ErrFileTypeForbidden) Error() string {
return "This file extension or type is not allowed to be uploaded."
}
var wildcardTypeRe = regexp.MustCompile(`^[a-z]+/\*$`)
// Verify validates whether a file is allowed to be uploaded.
func Verify(buf []byte, fileName, allowedTypesStr string) error {
allowedTypesStr = strings.ReplaceAll(allowedTypesStr, "|", ",") // compat for old config format
allowedTypes := []string{}
for _, entry := range strings.Split(allowedTypesStr, ",") {
entry = strings.ToLower(strings.TrimSpace(entry))
if entry != "" {
allowedTypes = append(allowedTypes, entry)
}
}
if len(allowedTypes) == 0 {
return nil // everything is allowed
}
fullMimeType := http.DetectContentType(buf)
mimeType, _, err := mime.ParseMediaType(fullMimeType)
if err != nil {
log.Warn("Detected attachment type could not be parsed %s", fullMimeType)
return ErrFileTypeForbidden{Type: fullMimeType}
}
extension := strings.ToLower(path.Ext(fileName))
// https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/file#Unique_file_type_specifiers
for _, allowEntry := range allowedTypes {
if allowEntry == "*/*" {
return nil // everything allowed
} else if strings.HasPrefix(allowEntry, ".") && allowEntry == extension {
return nil // extension is allowed
} else if mimeType == allowEntry {
return nil // mime type is allowed
} else if wildcardTypeRe.MatchString(allowEntry) && strings.HasPrefix(mimeType, allowEntry[:len(allowEntry)-1]) {
return nil // wildcard match, e.g. image/*
}
}
log.Info("Attachment with type %s blocked from upload", fullMimeType)
return ErrFileTypeForbidden{Type: fullMimeType}
}
// AddUploadContext renders template values for dropzone
func AddUploadContext(ctx *context.Context, uploadType string) {
if uploadType == "release" {
ctx.Data["UploadUrl"] = ctx.Repo.RepoLink + "/releases/attachments"
ctx.Data["UploadRemoveUrl"] = ctx.Repo.RepoLink + "/releases/attachments/remove"
ctx.Data["UploadLinkUrl"] = ctx.Repo.RepoLink + "/releases/attachments"
ctx.Data["UploadAccepts"] = strings.ReplaceAll(setting.Repository.Release.AllowedTypes, "|", ",")
ctx.Data["UploadMaxFiles"] = setting.Attachment.MaxFiles
ctx.Data["UploadMaxSize"] = setting.Attachment.MaxSize
} else if uploadType == "comment" {
ctx.Data["UploadUrl"] = ctx.Repo.RepoLink + "/issues/attachments"
ctx.Data["UploadRemoveUrl"] = ctx.Repo.RepoLink + "/issues/attachments/remove"
if len(ctx.Params(":index")) > 0 {
ctx.Data["UploadLinkUrl"] = ctx.Repo.RepoLink + "/issues/" + url.PathEscape(ctx.Params(":index")) + "/attachments"
} else {
ctx.Data["UploadLinkUrl"] = ctx.Repo.RepoLink + "/issues/attachments"
}
ctx.Data["UploadAccepts"] = strings.ReplaceAll(setting.Attachment.AllowedTypes, "|", ",")
ctx.Data["UploadMaxFiles"] = setting.Attachment.MaxFiles
ctx.Data["UploadMaxSize"] = setting.Attachment.MaxSize
} else if uploadType == "repo" {
ctx.Data["UploadUrl"] = ctx.Repo.RepoLink + "/upload-file"
ctx.Data["UploadRemoveUrl"] = ctx.Repo.RepoLink + "/upload-remove"
ctx.Data["UploadLinkUrl"] = ctx.Repo.RepoLink + "/upload-file"
ctx.Data["UploadAccepts"] = strings.ReplaceAll(setting.Repository.Upload.AllowedTypes, "|", ",")
ctx.Data["UploadMaxFiles"] = setting.Repository.Upload.MaxFiles
ctx.Data["UploadMaxSize"] = setting.Repository.Upload.FileMaxSize
}
}
+194
View File
@@ -0,0 +1,194 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package upload
import (
"bytes"
"compress/gzip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestUpload(t *testing.T) {
testContent := []byte(`This is a plain text file.`)
var b bytes.Buffer
w := gzip.NewWriter(&b)
w.Write(testContent)
w.Close()
kases := []struct {
data []byte
fileName string
allowedTypes string
err error
}{
{
data: testContent,
fileName: "test.txt",
allowedTypes: "",
err: nil,
},
{
data: testContent,
fileName: "dir/test.txt",
allowedTypes: "",
err: nil,
},
{
data: testContent,
fileName: "../../../test.txt",
allowedTypes: "",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: "",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: ",",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: "|",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: "*/*",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: "*/*,",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: "*/*|",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: "text/plain",
err: nil,
},
{
data: testContent,
fileName: "dir/test.txt",
allowedTypes: "text/plain",
err: nil,
},
{
data: testContent,
fileName: "/dir.txt/test.js",
allowedTypes: ".js",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: " text/plain ",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: ".txt",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: " .txt,.js",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: " .txt|.js",
err: nil,
},
{
data: testContent,
fileName: "../../test.txt",
allowedTypes: " .txt|.js",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: " .txt ,.js ",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: "text/plain, .txt",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: "text/*",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: "text/*,.js",
err: nil,
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: "text/**",
err: ErrFileTypeForbidden{"text/plain; charset=utf-8"},
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: "application/x-gzip",
err: ErrFileTypeForbidden{"text/plain; charset=utf-8"},
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: ".zip",
err: ErrFileTypeForbidden{"text/plain; charset=utf-8"},
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: ".zip,.txtx",
err: ErrFileTypeForbidden{"text/plain; charset=utf-8"},
},
{
data: testContent,
fileName: "test.txt",
allowedTypes: ".zip|.txtx",
err: ErrFileTypeForbidden{"text/plain; charset=utf-8"},
},
{
data: b.Bytes(),
fileName: "test.txt",
allowedTypes: "application/x-gzip",
err: nil,
},
}
for _, kase := range kases {
assert.Equal(t, kase.err, Verify(kase.data, kase.fileName, kase.allowedTypes))
}
}
+8 -9
View File
@@ -9,12 +9,11 @@ import (
"strings"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/context"
)
// UserAssignmentWeb returns a middleware to handle context-user assignment for web routes
func UserAssignmentWeb() func(ctx *context.Context) {
return func(ctx *context.Context) {
func UserAssignmentWeb() func(ctx *Context) {
return func(ctx *Context) {
errorFn := func(status int, title string, obj any) {
err, ok := obj.(error)
if !ok {
@@ -32,8 +31,8 @@ func UserAssignmentWeb() func(ctx *context.Context) {
}
// UserIDAssignmentAPI returns a middleware to handle context-user assignment for api routes
func UserIDAssignmentAPI() func(ctx *context.APIContext) {
return func(ctx *context.APIContext) {
func UserIDAssignmentAPI() func(ctx *APIContext) {
return func(ctx *APIContext) {
userID := ctx.ParamsInt64(":user-id")
if ctx.IsSigned && ctx.Doer.ID == userID {
@@ -53,13 +52,13 @@ func UserIDAssignmentAPI() func(ctx *context.APIContext) {
}
// UserAssignmentAPI returns a middleware to handle context-user assignment for api routes
func UserAssignmentAPI() func(ctx *context.APIContext) {
return func(ctx *context.APIContext) {
func UserAssignmentAPI() func(ctx *APIContext) {
return func(ctx *APIContext) {
ctx.ContextUser = userAssignment(ctx.Base, ctx.Doer, ctx.Error)
}
}
func userAssignment(ctx *context.Base, doer *user_model.User, errCb func(int, string, any)) (contextUser *user_model.User) {
func userAssignment(ctx *Base, doer *user_model.User, errCb func(int, string, any)) (contextUser *user_model.User) {
username := ctx.Params(":username")
if doer != nil && doer.LowerName == strings.ToLower(username) {
@@ -70,7 +69,7 @@ func userAssignment(ctx *context.Base, doer *user_model.User, errCb func(int, st
if err != nil {
if user_model.IsErrUserNotExist(err) {
if redirectUserID, err := user_model.LookupUserRedirect(ctx, username); err == nil {
context.RedirectToUser(ctx, username, redirectUserID)
RedirectToUser(ctx, username, redirectUserID)
} else if user_model.IsErrUserRedirectNotExist(err) {
errCb(http.StatusNotFound, "GetUserByName", err)
} else {
+38
View File
@@ -0,0 +1,38 @@
// Copyright 2017 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package context
import (
"strings"
"time"
)
// GetQueryBeforeSince return parsed time (unix format) from URL query's before and since
func GetQueryBeforeSince(ctx *Base) (before, since int64, err error) {
before, err = parseFormTime(ctx, "before")
if err != nil {
return 0, 0, err
}
since, err = parseFormTime(ctx, "since")
if err != nil {
return 0, 0, err
}
return before, since, nil
}
// parseTime parse time and return unix timestamp
func parseFormTime(ctx *Base, name string) (int64, error) {
value := strings.TrimSpace(ctx.FormString(name))
if len(value) != 0 {
t, err := time.Parse(time.RFC3339, value)
if err != nil {
return 0, err
}
if !t.IsZero() {
return t.Unix(), nil
}
}
return 0, nil
}
+99
View File
@@ -0,0 +1,99 @@
// Copyright 2012 Google Inc. All Rights Reserved.
// Copyright 2014 The Macaron Authors
// Copyright 2020 The Gitea Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// SPDX-License-Identifier: Apache-2.0
package context
import (
"bytes"
"crypto/hmac"
"crypto/sha1"
"crypto/subtle"
"encoding/base64"
"fmt"
"strconv"
"strings"
"time"
)
// CsrfTokenTimeout represents the duration that XSRF tokens are valid.
// It is exported so clients may set cookie timeouts that match generated tokens.
const CsrfTokenTimeout = 24 * time.Hour
// CsrfTokenRegenerationInterval is the interval between token generations, old tokens are still valid before CsrfTokenTimeout
var CsrfTokenRegenerationInterval = 10 * time.Minute
var csrfTokenSep = []byte(":")
// GenerateCsrfToken returns a URL-safe secure XSRF token that expires in CsrfTokenTimeout hours.
// key is a secret key for your application.
// userID is a unique identifier for the user.
// actionID is the action the user is taking (e.g. POSTing to a particular path).
func GenerateCsrfToken(key, userID, actionID string, now time.Time) string {
nowUnixNano := now.UnixNano()
nowUnixNanoStr := strconv.FormatInt(nowUnixNano, 10)
h := hmac.New(sha1.New, []byte(key))
h.Write([]byte(strings.ReplaceAll(userID, ":", "_")))
h.Write(csrfTokenSep)
h.Write([]byte(strings.ReplaceAll(actionID, ":", "_")))
h.Write(csrfTokenSep)
h.Write([]byte(nowUnixNanoStr))
tok := fmt.Sprintf("%s:%s", h.Sum(nil), nowUnixNanoStr)
return base64.RawURLEncoding.EncodeToString([]byte(tok))
}
func ParseCsrfToken(token string) (issueTime time.Time, ok bool) {
data, err := base64.RawURLEncoding.DecodeString(token)
if err != nil {
return time.Time{}, false
}
pos := bytes.LastIndex(data, csrfTokenSep)
if pos == -1 {
return time.Time{}, false
}
nanos, err := strconv.ParseInt(string(data[pos+1:]), 10, 64)
if err != nil {
return time.Time{}, false
}
return time.Unix(0, nanos), true
}
// ValidCsrfToken returns true if token is a valid and unexpired token returned by Generate.
func ValidCsrfToken(token, key, userID, actionID string, now time.Time) bool {
issueTime, ok := ParseCsrfToken(token)
if !ok {
return false
}
// Check that the token is not expired.
if now.Sub(issueTime) >= CsrfTokenTimeout {
return false
}
// Check that the token is not from the future.
// Allow 1-minute grace period in case the token is being verified on a
// machine whose clock is behind the machine that issued the token.
if issueTime.After(now.Add(1 * time.Minute)) {
return false
}
expected := GenerateCsrfToken(key, userID, actionID, issueTime)
// Check that the token matches the expected value.
// Use constant time comparison to avoid timing attacks.
return subtle.ConstantTimeCompare([]byte(token), []byte(expected)) == 1
}
+91
View File
@@ -0,0 +1,91 @@
// Copyright 2012 Google Inc. All Rights Reserved.
// Copyright 2014 The Macaron Authors
// Copyright 2020 The Gitea Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// SPDX-License-Identifier: Apache-2.0
package context
import (
"encoding/base64"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const (
key = "quay"
userID = "12345678"
actionID = "POST /form"
)
var (
now = time.Now()
oneMinuteFromNow = now.Add(1 * time.Minute)
)
func Test_ValidToken(t *testing.T) {
t.Run("Validate token", func(t *testing.T) {
tok := GenerateCsrfToken(key, userID, actionID, now)
assert.True(t, ValidCsrfToken(tok, key, userID, actionID, oneMinuteFromNow))
assert.True(t, ValidCsrfToken(tok, key, userID, actionID, now.Add(CsrfTokenTimeout-1*time.Nanosecond)))
assert.True(t, ValidCsrfToken(tok, key, userID, actionID, now.Add(-1*time.Minute)))
})
}
// Test_SeparatorReplacement tests that separators are being correctly substituted
func Test_SeparatorReplacement(t *testing.T) {
t.Run("Test two separator replacements", func(t *testing.T) {
assert.NotEqual(t, GenerateCsrfToken("foo:bar", "baz", "wah", now),
GenerateCsrfToken("foo", "bar:baz", "wah", now))
})
}
func Test_InvalidToken(t *testing.T) {
t.Run("Test invalid tokens", func(t *testing.T) {
invalidTokenTests := []struct {
name, key, userID, actionID string
t time.Time
}{
{"Bad key", "foobar", userID, actionID, oneMinuteFromNow},
{"Bad userID", key, "foobar", actionID, oneMinuteFromNow},
{"Bad actionID", key, userID, "foobar", oneMinuteFromNow},
{"Expired", key, userID, actionID, now.Add(CsrfTokenTimeout)},
{"More than 1 minute from the future", key, userID, actionID, now.Add(-1*time.Nanosecond - 1*time.Minute)},
}
tok := GenerateCsrfToken(key, userID, actionID, now)
for _, itt := range invalidTokenTests {
assert.False(t, ValidCsrfToken(tok, itt.key, itt.userID, itt.actionID, itt.t))
}
})
}
// Test_ValidateBadData primarily tests that no unexpected panics are triggered during parsing
func Test_ValidateBadData(t *testing.T) {
t.Run("Validate bad data", func(t *testing.T) {
badDataTests := []struct {
name, tok string
}{
{"Invalid Base64", "ASDab24(@)$*=="},
{"No delimiter", base64.URLEncoding.EncodeToString([]byte("foobar12345678"))},
{"Invalid time", base64.URLEncoding.EncodeToString([]byte("foobar:foobar"))},
}
for _, bdt := range badDataTests {
assert.False(t, ValidCsrfToken(bdt.tok, key, userID, actionID, oneMinuteFromNow))
}
})
}
+170
View File
@@ -0,0 +1,170 @@
// Copyright 2017 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
// Package contexttest provides utilities for testing Web/API contexts with models.
package contexttest
import (
gocontext "context"
"io"
"maps"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
access_model "code.gitea.io/gitea/models/perm/access"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/gitrepo"
"code.gitea.io/gitea/modules/templates"
"code.gitea.io/gitea/modules/translation"
"code.gitea.io/gitea/modules/web/middleware"
"code.gitea.io/gitea/services/context"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
)
func mockRequest(t *testing.T, reqPath string) *http.Request {
method, path, found := strings.Cut(reqPath, " ")
if !found {
method = "GET"
path = reqPath
}
requestURL, err := url.Parse(path)
assert.NoError(t, err)
req := &http.Request{Method: method, URL: requestURL, Form: maps.Clone(requestURL.Query()), Header: http.Header{}}
req = req.WithContext(middleware.WithContextData(req.Context()))
return req
}
type MockContextOption struct {
Render context.Render
}
// MockContext mock context for unit tests
func MockContext(t *testing.T, reqPath string, opts ...MockContextOption) (*context.Context, *httptest.ResponseRecorder) {
var opt MockContextOption
if len(opts) > 0 {
opt = opts[0]
}
if opt.Render == nil {
opt.Render = &MockRender{}
}
resp := httptest.NewRecorder()
req := mockRequest(t, reqPath)
base, baseCleanUp := context.NewBaseContext(resp, req)
_ = baseCleanUp // during test, it doesn't need to do clean up. TODO: this can be improved later
base.Data = middleware.GetContextData(req.Context())
base.Locale = &translation.MockLocale{}
ctx := context.NewWebContext(base, opt.Render, nil)
ctx.AppendContextValue(context.WebContextKey, ctx)
ctx.PageData = map[string]any{}
ctx.Data["PageStartTime"] = time.Now()
chiCtx := chi.NewRouteContext()
ctx.Base.AppendContextValue(chi.RouteCtxKey, chiCtx)
return ctx, resp
}
// MockAPIContext mock context for unit tests
func MockAPIContext(t *testing.T, reqPath string) (*context.APIContext, *httptest.ResponseRecorder) {
resp := httptest.NewRecorder()
req := mockRequest(t, reqPath)
base, baseCleanUp := context.NewBaseContext(resp, req)
base.Data = middleware.GetContextData(req.Context())
base.Locale = &translation.MockLocale{}
ctx := &context.APIContext{Base: base}
_ = baseCleanUp // during test, it doesn't need to do clean up. TODO: this can be improved later
chiCtx := chi.NewRouteContext()
ctx.Base.AppendContextValue(chi.RouteCtxKey, chiCtx)
return ctx, resp
}
// LoadRepo load a repo into a test context.
func LoadRepo(t *testing.T, ctx gocontext.Context, repoID int64) {
var doer *user_model.User
repo := &context.Repository{}
switch ctx := ctx.(type) {
case *context.Context:
ctx.Repo = repo
doer = ctx.Doer
case *context.APIContext:
ctx.Repo = repo
doer = ctx.Doer
default:
assert.FailNow(t, "context is not *context.Context or *context.APIContext")
}
repo.Repository = unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID})
var err error
repo.Owner, err = user_model.GetUserByID(ctx, repo.Repository.OwnerID)
assert.NoError(t, err)
repo.RepoLink = repo.Repository.Link()
repo.Permission, err = access_model.GetUserRepoPermission(ctx, repo.Repository, doer)
assert.NoError(t, err)
}
// LoadRepoCommit loads a repo's commit into a test context.
func LoadRepoCommit(t *testing.T, ctx gocontext.Context) {
var repo *context.Repository
switch ctx := ctx.(type) {
case *context.Context:
repo = ctx.Repo
case *context.APIContext:
repo = ctx.Repo
default:
assert.FailNow(t, "context is not *context.Context or *context.APIContext")
}
gitRepo, err := gitrepo.OpenRepository(ctx, repo.Repository)
assert.NoError(t, err)
defer gitRepo.Close()
branch, err := gitRepo.GetHEADBranch()
assert.NoError(t, err)
assert.NotNil(t, branch)
if branch != nil {
repo.Commit, err = gitRepo.GetBranchCommit(branch.Name)
assert.NoError(t, err)
}
}
// LoadUser load a user into a test context
func LoadUser(t *testing.T, ctx gocontext.Context, userID int64) {
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: userID})
switch ctx := ctx.(type) {
case *context.Context:
ctx.Doer = doer
case *context.APIContext:
ctx.Doer = doer
default:
assert.FailNow(t, "context is not *context.Context or *context.APIContext")
}
}
// LoadGitRepo load a git repo into a test context. Requires that ctx.Repo has
// already been populated.
func LoadGitRepo(t *testing.T, ctx *context.Context) {
assert.NoError(t, ctx.Repo.Repository.LoadOwner(ctx))
var err error
ctx.Repo.GitRepo, err = gitrepo.OpenRepository(ctx, ctx.Repo.Repository)
assert.NoError(t, err)
}
type MockRender struct{}
func (tr *MockRender) TemplateLookup(tmpl string, _ gocontext.Context) (templates.TemplateExecutor, error) {
return nil, nil
}
func (tr *MockRender) HTML(w io.Writer, status int, _ string, _ any, _ gocontext.Context) error {
if resp, ok := w.(http.ResponseWriter); ok {
resp.WriteHeader(status)
}
return nil
}
+19 -20
View File
@@ -159,6 +159,7 @@ func ToBranchProtection(ctx context.Context, bp *git_model.ProtectedBranch) *api
BlockOnOfficialReviewRequests: bp.BlockOnOfficialReviewRequests,
BlockOnOutdatedBranch: bp.BlockOnOutdatedBranch,
DismissStaleApprovals: bp.DismissStaleApprovals,
IgnoreStaleApprovals: bp.IgnoreStaleApprovals,
RequireSignedCommits: bp.RequireSignedCommits,
ProtectedFilePatterns: bp.ProtectedFilePatterns,
UnprotectedFilePatterns: bp.UnprotectedFilePatterns,
@@ -332,40 +333,38 @@ func ToTeam(ctx context.Context, team *organization.Team, loadOrg ...bool) (*api
// ToTeams convert models.Team list to api.Team list
func ToTeams(ctx context.Context, teams []*organization.Team, loadOrgs bool) ([]*api.Team, error) {
if len(teams) == 0 || teams[0] == nil {
return nil, nil
}
cache := make(map[int64]*api.Organization)
apiTeams := make([]*api.Team, len(teams))
for i := range teams {
if err := teams[i].LoadUnits(ctx); err != nil {
apiTeams := make([]*api.Team, 0, len(teams))
for _, t := range teams {
if err := t.LoadUnits(ctx); err != nil {
return nil, err
}
apiTeams[i] = &api.Team{
ID: teams[i].ID,
Name: teams[i].Name,
Description: teams[i].Description,
IncludesAllRepositories: teams[i].IncludesAllRepositories,
CanCreateOrgRepo: teams[i].CanCreateOrgRepo,
Permission: teams[i].AccessMode.String(),
Units: teams[i].GetUnitNames(),
UnitsMap: teams[i].GetUnitsMap(),
apiTeam := &api.Team{
ID: t.ID,
Name: t.Name,
Description: t.Description,
IncludesAllRepositories: t.IncludesAllRepositories,
CanCreateOrgRepo: t.CanCreateOrgRepo,
Permission: t.AccessMode.String(),
Units: t.GetUnitNames(),
UnitsMap: t.GetUnitsMap(),
}
if loadOrgs {
apiOrg, ok := cache[teams[i].OrgID]
apiOrg, ok := cache[t.OrgID]
if !ok {
org, err := organization.GetOrgByID(ctx, teams[i].OrgID)
org, err := organization.GetOrgByID(ctx, t.OrgID)
if err != nil {
return nil, err
}
apiOrg = ToOrganization(ctx, org)
cache[teams[i].OrgID] = apiOrg
cache[t.OrgID] = apiOrg
}
apiTeams[i].Organization = apiOrg
apiTeam.Organization = apiOrg
}
apiTeams = append(apiTeams, apiTeam)
}
return apiTeams, nil
}
+1 -1
View File
@@ -10,11 +10,11 @@ import (
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
ctx "code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/log"
api "code.gitea.io/gitea/modules/structs"
"code.gitea.io/gitea/modules/util"
ctx "code.gitea.io/gitea/services/context"
"code.gitea.io/gitea/services/gitdiff"
)
+5 -5
View File
@@ -19,12 +19,12 @@ import (
func TestToCommitMeta(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
headRepo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
sha1, _ := git.NewIDFromString("0000000000000000000000000000000000000000")
sha1 := git.Sha1ObjectFormat
signature := &git.Signature{Name: "Test Signature", Email: "test@email.com", When: time.Unix(0, 0)}
tag := &git.Tag{
Name: "Test Tag",
ID: sha1,
Object: sha1,
ID: sha1.EmptyObjectID(),
Object: sha1.EmptyObjectID(),
Type: "Test Type",
Tagger: signature,
Message: "Test Message",
@@ -34,8 +34,8 @@ func TestToCommitMeta(t *testing.T) {
assert.NotNil(t, commitMeta)
assert.EqualValues(t, &api.CommitMeta{
SHA: "0000000000000000000000000000000000000000",
URL: util.URLJoin(headRepo.APIURL(), "git/commits", "0000000000000000000000000000000000000000"),
SHA: sha1.EmptyObjectID().String(),
URL: util.URLJoin(headRepo.APIURL(), "git/commits", sha1.EmptyObjectID().String()),
Created: time.Unix(0, 0),
}, commitMeta)
}
+2 -1
View File
@@ -98,7 +98,8 @@ func toIssue(ctx context.Context, issue *issues_model.Issue, getDownloadURL func
}
if issue.PullRequest != nil {
apiIssue.PullRequest = &api.PullRequestMeta{
HasMerged: issue.PullRequest.HasMerged,
HasMerged: issue.PullRequest.HasMerged,
IsWorkInProgress: issue.PullRequest.IsWorkInProgress(ctx),
}
if issue.PullRequest.HasMerged {
apiIssue.PullRequest.Merged = issue.PullRequest.MergedUnix.AsTimePtr()
+3 -2
View File
@@ -61,8 +61,9 @@ func ToNotificationThread(ctx context.Context, n *activities_model.Notification)
result.Subject.LatestCommentHTMLURL = comment.HTMLURL(ctx)
}
pr, _ := n.Issue.GetPullRequest(ctx)
if pr != nil && pr.HasMerged {
if err := n.Issue.LoadPullRequest(ctx); err == nil &&
n.Issue.PullRequest != nil &&
n.Issue.PullRequest.HasMerged {
result.Subject.State = "merged"
}
}
+1
View File
@@ -35,6 +35,7 @@ func ToPackage(ctx context.Context, pd *packages.PackageDescriptor, doer *user_m
Name: pd.Package.Name,
Version: pd.Version.Version,
CreatedAt: pd.Version.CreatedUnix.AsTime(),
HTMLURL: pd.VersionHTMLURL(),
}, nil
}
+5 -4
View File
@@ -12,6 +12,7 @@ import (
access_model "code.gitea.io/gitea/models/perm/access"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/gitrepo"
"code.gitea.io/gitea/modules/log"
api "code.gitea.io/gitea/modules/structs"
)
@@ -101,7 +102,7 @@ func ToAPIPullRequest(ctx context.Context, pr *issues_model.PullRequest, doer *u
apiPullRequest.Closed = pr.Issue.ClosedUnix.AsTimePtr()
}
gitRepo, err := git.OpenRepository(ctx, pr.BaseRepo.RepoPath())
gitRepo, err := gitrepo.OpenRepository(ctx, pr.BaseRepo)
if err != nil {
log.Error("OpenRepository[%s]: %v", pr.BaseRepo.RepoPath(), err)
return nil
@@ -127,7 +128,7 @@ func ToAPIPullRequest(ctx context.Context, pr *issues_model.PullRequest, doer *u
}
if pr.Flow == issues_model.PullRequestFlowAGit {
gitRepo, err := git.OpenRepository(ctx, pr.BaseRepo.RepoPath())
gitRepo, err := gitrepo.OpenRepository(ctx, pr.BaseRepo)
if err != nil {
log.Error("OpenRepository[%s]: %v", pr.GetGitRefName(), err)
return nil
@@ -154,7 +155,7 @@ func ToAPIPullRequest(ctx context.Context, pr *issues_model.PullRequest, doer *u
apiPullRequest.Head.RepoID = pr.HeadRepo.ID
apiPullRequest.Head.Repository = ToRepo(ctx, pr.HeadRepo, p)
headGitRepo, err := git.OpenRepository(ctx, pr.HeadRepo.RepoPath())
headGitRepo, err := gitrepo.OpenRepository(ctx, pr.HeadRepo)
if err != nil {
log.Error("OpenRepository[%s]: %v", pr.HeadRepo.RepoPath(), err)
return nil
@@ -190,7 +191,7 @@ func ToAPIPullRequest(ctx context.Context, pr *issues_model.PullRequest, doer *u
}
if len(apiPullRequest.Head.Sha) == 0 && len(apiPullRequest.Head.Ref) != 0 {
baseGitRepo, err := git.OpenRepository(ctx, pr.BaseRepo.RepoPath())
baseGitRepo, err := gitrepo.OpenRepository(ctx, pr.BaseRepo)
if err != nil {
log.Error("OpenRepository[%s]: %v", pr.BaseRepo.RepoPath(), err)
return nil
+9 -7
View File
@@ -21,15 +21,9 @@ func ToPullReview(ctx context.Context, r *issues_model.Review, doer *user_model.
r.Reviewer = user_model.NewGhostUser()
}
apiTeam, err := ToTeam(ctx, r.ReviewerTeam)
if err != nil {
return nil, err
}
result := &api.PullReview{
ID: r.ID,
Reviewer: ToUser(ctx, r.Reviewer, doer),
ReviewerTeam: apiTeam,
State: api.ReviewStateUnknown,
Body: r.Content,
CommitID: r.CommitID,
@@ -43,6 +37,14 @@ func ToPullReview(ctx context.Context, r *issues_model.Review, doer *user_model.
HTMLPullURL: r.Issue.HTMLURL(),
}
if r.ReviewerTeam != nil {
var err error
result.ReviewerTeam, err = ToTeam(ctx, r.ReviewerTeam)
if err != nil {
return nil, err
}
}
switch r.Type {
case issues_model.ReviewTypeApprove:
result.State = api.ReviewStateApproved
@@ -64,7 +66,7 @@ func ToPullReviewList(ctx context.Context, rl []*issues_model.Review, doer *user
result := make([]*api.PullReview, 0, len(rl))
for i := range rl {
// show pending reviews only for the user who created them
if rl[i].Type == issues_model.ReviewTypePending && !(doer.IsAdmin || doer.ID == rl[i].ReviewerID) {
if rl[i].Type == issues_model.ReviewTypePending && (doer == nil || (!doer.IsAdmin && doer.ID != rl[i].ReviewerID)) {
continue
}
r, err := ToPullReview(ctx, rl[i], doer)
+52
View File
@@ -0,0 +1,52 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package convert
import (
"testing"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"github.com/stretchr/testify/assert"
)
func Test_ToPullReview(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
reviewer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
review := unittest.AssertExistsAndLoadBean(t, &issues_model.Review{ID: 6})
assert.EqualValues(t, reviewer.ID, review.ReviewerID)
assert.EqualValues(t, issues_model.ReviewTypePending, review.Type)
reviewList := []*issues_model.Review{review}
t.Run("Anonymous User", func(t *testing.T) {
prList, err := ToPullReviewList(db.DefaultContext, reviewList, nil)
assert.NoError(t, err)
assert.Empty(t, prList)
})
t.Run("Reviewer Himself", func(t *testing.T) {
prList, err := ToPullReviewList(db.DefaultContext, reviewList, reviewer)
assert.NoError(t, err)
assert.Len(t, prList, 1)
})
t.Run("Other User", func(t *testing.T) {
user4 := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 4})
prList, err := ToPullReviewList(db.DefaultContext, reviewList, user4)
assert.NoError(t, err)
assert.Len(t, prList, 0)
})
t.Run("Admin User", func(t *testing.T) {
adminUser := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1})
prList, err := ToPullReviewList(db.DefaultContext, reviewList, adminUser)
assert.NoError(t, err)
assert.Len(t, prList, 1)
})
}
+14 -2
View File
@@ -8,6 +8,7 @@ import (
"time"
"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/perm"
access_model "code.gitea.io/gitea/models/perm/access"
repo_model "code.gitea.io/gitea/models/repo"
@@ -92,6 +93,7 @@ func innerToRepo(ctx context.Context, repo *repo_model.Repository, permissionInR
allowRebase := false
allowRebaseMerge := false
allowSquash := false
allowFastForwardOnly := false
allowRebaseUpdate := false
defaultDeleteBranchAfterMerge := false
defaultMergeStyle := repo_model.MergeStyleMerge
@@ -104,14 +106,18 @@ func innerToRepo(ctx context.Context, repo *repo_model.Repository, permissionInR
allowRebase = config.AllowRebase
allowRebaseMerge = config.AllowRebaseMerge
allowSquash = config.AllowSquash
allowFastForwardOnly = config.AllowFastForwardOnly
allowRebaseUpdate = config.AllowRebaseUpdate
defaultDeleteBranchAfterMerge = config.DefaultDeleteBranchAfterMerge
defaultMergeStyle = config.GetDefaultMergeStyle()
defaultAllowMaintainerEdit = config.DefaultAllowMaintainerEdit
}
hasProjects := false
if _, err := repo.GetUnit(ctx, unit_model.TypeProjects); err == nil {
projectsMode := repo_model.ProjectsModeAll
if unit, err := repo.GetUnit(ctx, unit_model.TypeProjects); err == nil {
hasProjects = true
config := unit.ProjectsConfig()
projectsMode = config.ProjectsMode
}
hasReleases := false
@@ -133,7 +139,11 @@ func innerToRepo(ctx context.Context, repo *repo_model.Repository, permissionInR
return nil
}
numReleases, _ := repo_model.GetReleaseCountByRepoID(ctx, repo.ID, repo_model.FindReleasesOptions{IncludeDrafts: false, IncludeTags: false})
numReleases, _ := db.Count[repo_model.Release](ctx, repo_model.FindReleasesOptions{
IncludeDrafts: false,
IncludeTags: false,
RepoID: repo.ID,
})
mirrorInterval := ""
var mirrorUpdated time.Time
@@ -204,6 +214,7 @@ func innerToRepo(ctx context.Context, repo *repo_model.Repository, permissionInR
InternalTracker: internalTracker,
HasWiki: hasWiki,
HasProjects: hasProjects,
ProjectsMode: string(projectsMode),
HasReleases: hasReleases,
HasPackages: hasPackages,
HasActions: hasActions,
@@ -214,6 +225,7 @@ func innerToRepo(ctx context.Context, repo *repo_model.Repository, permissionInR
AllowRebase: allowRebase,
AllowRebaseMerge: allowRebaseMerge,
AllowSquash: allowSquash,
AllowFastForwardOnly: allowFastForwardOnly,
AllowRebaseUpdate: allowRebaseUpdate,
DefaultDeleteBranchAfterMerge: defaultDeleteBranchAfterMerge,
DefaultMergeStyle: string(defaultMergeStyle),
-15
View File
@@ -6,11 +6,8 @@ package convert
import (
"time"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/modules/git"
api "code.gitea.io/gitea/modules/structs"
"code.gitea.io/gitea/modules/util"
wiki_service "code.gitea.io/gitea/services/wiki"
)
// ToWikiCommit convert a git commit into a WikiCommit
@@ -46,15 +43,3 @@ func ToWikiCommitList(commits []*git.Commit, total int64) *api.WikiCommitList {
Count: total,
}
}
// ToWikiPageMetaData converts meta information to a WikiPageMetaData
func ToWikiPageMetaData(wikiName wiki_service.WebPath, lastCommit *git.Commit, repo *repo_model.Repository) *api.WikiPageMetaData {
subURL := string(wikiName)
_, title := wiki_service.WebPathToUserTitle(wikiName)
return &api.WikiPageMetaData{
Title: title,
HTMLURL: util.URLJoin(repo.HTMLURL(), "wiki", subURL),
SubURL: subURL,
LastCommit: ToWikiCommit(lastCommit),
}
}
+3 -3
View File
@@ -70,7 +70,7 @@ func (b *BaseConfig) DoNoticeOnSuccess() bool {
// Please note the `status` string will be concatenated with `admin.dashboard.cron.` and `admin.dashboard.task.` to provide locale messages. Similarly `name` will be composed with `admin.dashboard.` to provide the locale name for the task.
func (b *BaseConfig) FormatMessage(locale translation.Locale, name, status, doer string, args ...any) string {
realArgs := make([]any, 0, len(args)+2)
realArgs = append(realArgs, locale.Tr("admin.dashboard."+name))
realArgs = append(realArgs, locale.TrString("admin.dashboard."+name))
if doer == "" {
realArgs = append(realArgs, "(Cron)")
} else {
@@ -80,7 +80,7 @@ func (b *BaseConfig) FormatMessage(locale translation.Locale, name, status, doer
realArgs = append(realArgs, args...)
}
if doer == "" {
return locale.Tr("admin.dashboard.cron."+status, realArgs...)
return locale.TrString("admin.dashboard.cron."+status, realArgs...)
}
return locale.Tr("admin.dashboard.task."+status, realArgs...)
return locale.TrString("admin.dashboard.task."+status, realArgs...)
}
+8 -6
View File
@@ -84,13 +84,15 @@ func (t *Task) RunWithUser(doer *user_model.User, config Config) {
t.lock.Unlock()
defer func() {
taskStatusTable.Stop(t.Name)
if err := recover(); err != nil {
// Recover a panic within the
combinedErr := fmt.Errorf("%s\n%s", err, log.Stack(2))
log.Error("PANIC whilst running task: %s Value: %v", t.Name, combinedErr)
}
}()
graceful.GetManager().RunWithShutdownContext(func(baseCtx context.Context) {
defer func() {
if err := recover(); err != nil {
// Recover a panic within the execution of the task.
combinedErr := fmt.Errorf("%s\n%s", err, log.Stack(2))
log.Error("PANIC whilst running task: %s Value: %v", t.Name, combinedErr)
}
}()
// Store the time of this run, before the function is executed, so it
// matches the behavior of what the cron library does.
t.lock.Lock()
@@ -157,7 +159,7 @@ func RegisterTask(name string, config Config, fun func(context.Context, *user_mo
log.Debug("Registering task: %s", name)
i18nKey := "admin.dashboard." + name
if value := translation.NewLocale("en-US").Tr(i18nKey); value == i18nKey {
if value := translation.NewLocale("en-US").TrString(i18nKey); value == i18nKey {
return fmt.Errorf("translation is missing for task %q, please add translation for %q", name, i18nKey)
}
+3 -3
View File
@@ -8,13 +8,13 @@ import (
"time"
activities_model "code.gitea.io/gitea/models/activities"
asymkey_model "code.gitea.io/gitea/models/asymkey"
"code.gitea.io/gitea/models/system"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/git"
issue_indexer "code.gitea.io/gitea/modules/indexer/issues"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/updatechecker"
asymkey_service "code.gitea.io/gitea/services/asymkey"
repo_service "code.gitea.io/gitea/services/repository"
archiver_service "code.gitea.io/gitea/services/repository/archiver"
user_service "code.gitea.io/gitea/services/user"
@@ -71,7 +71,7 @@ func registerRewriteAllPublicKeys() {
RunAtStart: false,
Schedule: "@every 72h",
}, func(ctx context.Context, _ *user_model.User, _ Config) error {
return asymkey_model.RewriteAllPublicKeys(ctx)
return asymkey_service.RewriteAllPublicKeys(ctx)
})
}
@@ -81,7 +81,7 @@ func registerRewriteAllPrincipalKeys() {
RunAtStart: false,
Schedule: "@every 72h",
}, func(ctx context.Context, _ *user_model.User, _ Config) error {
return asymkey_model.RewriteAllPrincipalKeys(ctx)
return asymkey_service.RewriteAllPrincipalKeys(ctx)
})
}
+12 -6
View File
@@ -4,6 +4,7 @@
package cron
import (
"sort"
"strconv"
"testing"
@@ -22,9 +23,10 @@ func TestAddTaskToScheduler(t *testing.T) {
},
})
assert.NoError(t, err)
assert.Len(t, scheduler.Jobs(), 1)
assert.Equal(t, "task 1", scheduler.Jobs()[0].Tags()[0])
assert.Equal(t, "5 4 * * *", scheduler.Jobs()[0].Tags()[1])
jobs := scheduler.Jobs()
assert.Len(t, jobs, 1)
assert.Equal(t, "task 1", jobs[0].Tags()[0])
assert.Equal(t, "5 4 * * *", jobs[0].Tags()[1])
// with seconds
err = addTaskToScheduler(&Task{
@@ -34,9 +36,13 @@ func TestAddTaskToScheduler(t *testing.T) {
},
})
assert.NoError(t, err)
assert.Len(t, scheduler.Jobs(), 2)
assert.Equal(t, "task 2", scheduler.Jobs()[1].Tags()[0])
assert.Equal(t, "30 5 4 * * *", scheduler.Jobs()[1].Tags()[1])
jobs = scheduler.Jobs() // the item order is not guaranteed, so we need to sort it before "assert"
sort.Slice(jobs, func(i, j int) bool {
return jobs[i].Tags()[0] < jobs[j].Tags()[0]
})
assert.Len(t, jobs, 2)
assert.Equal(t, "task 2", jobs[1].Tags()[0])
assert.Equal(t, "30 5 4 * * *", jobs[1].Tags()[1])
}
func TestScheduleHasSeconds(t *testing.T) {
+101
View File
@@ -0,0 +1,101 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package doctor
import (
"bufio"
"bytes"
"context"
"fmt"
"os"
"path/filepath"
"strings"
asymkey_model "code.gitea.io/gitea/models/asymkey"
"code.gitea.io/gitea/modules/container"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
asymkey_service "code.gitea.io/gitea/services/asymkey"
)
const tplCommentPrefix = `# gitea public key`
func checkAuthorizedKeys(ctx context.Context, logger log.Logger, autofix bool) error {
if setting.SSH.StartBuiltinServer || !setting.SSH.CreateAuthorizedKeysFile {
return nil
}
fPath := filepath.Join(setting.SSH.RootPath, "authorized_keys")
f, err := os.Open(fPath)
if err != nil {
if !autofix {
logger.Critical("Unable to open authorized_keys file. ERROR: %v", err)
return fmt.Errorf("Unable to open authorized_keys file. ERROR: %w", err)
}
logger.Warn("Unable to open authorized_keys. (ERROR: %v). Attempting to rewrite...", err)
if err = asymkey_service.RewriteAllPublicKeys(ctx); err != nil {
logger.Critical("Unable to rewrite authorized_keys file. ERROR: %v", err)
return fmt.Errorf("Unable to rewrite authorized_keys file. ERROR: %w", err)
}
}
defer f.Close()
linesInAuthorizedKeys := make(container.Set[string])
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, tplCommentPrefix) {
continue
}
linesInAuthorizedKeys.Add(line)
}
if err = scanner.Err(); err != nil {
return fmt.Errorf("scan: %w", err)
}
// although there is a "defer close" above, here close explicitly before the generating, because it needs to open the file for writing again
_ = f.Close()
// now we regenerate and check if there are any lines missing
regenerated := &bytes.Buffer{}
if err := asymkey_model.RegeneratePublicKeys(ctx, regenerated); err != nil {
logger.Critical("Unable to regenerate authorized_keys file. ERROR: %v", err)
return fmt.Errorf("Unable to regenerate authorized_keys file. ERROR: %w", err)
}
scanner = bufio.NewScanner(regenerated)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, tplCommentPrefix) {
continue
}
if linesInAuthorizedKeys.Contains(line) {
continue
}
if !autofix {
logger.Critical(
"authorized_keys file %q is out of date.\nRegenerate it with:\n\t\"%s\"\nor\n\t\"%s\"",
fPath,
"gitea admin regenerate keys",
"gitea doctor --run authorized-keys --fix")
return fmt.Errorf(`authorized_keys is out of date and should be regenerated with "gitea admin regenerate keys" or "gitea doctor --run authorized-keys --fix"`)
}
logger.Warn("authorized_keys is out of date. Attempting rewrite...")
err = asymkey_service.RewriteAllPublicKeys(ctx)
if err != nil {
logger.Critical("Unable to rewrite authorized_keys file. ERROR: %v", err)
return fmt.Errorf("Unable to rewrite authorized_keys file. ERROR: %w", err)
}
}
return nil
}
func init() {
Register(&Check{
Title: "Check if OpenSSH authorized_keys file is up-to-date",
Name: "authorized-keys",
IsDefault: true,
Run: checkAuthorizedKeys,
Priority: 4,
})
}
+97
View File
@@ -0,0 +1,97 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package doctor
import (
"context"
"fmt"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/log"
"xorm.io/builder"
)
func iterateUserAccounts(ctx context.Context, each func(*user.User) error) error {
err := db.Iterate(
ctx,
builder.Gt{"id": 0},
func(ctx context.Context, bean *user.User) error {
return each(bean)
},
)
return err
}
// Since 1.16.4 new restrictions has been set on email addresses. However users with invalid email
// addresses would be currently facing a error due to their invalid email address.
// Ref: https://github.com/go-gitea/gitea/pull/19085 & https://github.com/go-gitea/gitea/pull/17688
func checkUserEmail(ctx context.Context, logger log.Logger, _ bool) error {
// We could use quirky SQL to get all users that start without a [a-zA-Z0-9], but that would mean
// DB provider-specific SQL and only works _now_. So instead we iterate through all user accounts
// and use the user.ValidateEmail function to be future-proof.
var invalidUserCount int64
if err := iterateUserAccounts(ctx, func(u *user.User) error {
// Only check for users, skip
if u.Type != user.UserTypeIndividual {
return nil
}
if err := user.ValidateEmail(u.Email); err != nil {
invalidUserCount++
logger.Warn("User[id=%d name=%q] have not a valid e-mail: %v", u.ID, u.Name, err)
}
return nil
}); err != nil {
return fmt.Errorf("iterateUserAccounts: %w", err)
}
if invalidUserCount == 0 {
logger.Info("All users have a valid e-mail.")
} else {
logger.Warn("%d user(s) have a non-valid e-mail.", invalidUserCount)
}
return nil
}
// From time to time Gitea makes changes to the reserved usernames and which symbols
// are allowed for various reasons. This check helps with detecting users that, according
// to our reserved names, don't have a valid username.
func checkUserName(ctx context.Context, logger log.Logger, _ bool) error {
var invalidUserCount int64
if err := iterateUserAccounts(ctx, func(u *user.User) error {
if err := user.IsUsableUsername(u.Name); err != nil {
invalidUserCount++
logger.Warn("User[id=%d] does not have a valid username: %v", u.ID, err)
}
return nil
}); err != nil {
return fmt.Errorf("iterateUserAccounts: %w", err)
}
if invalidUserCount == 0 {
logger.Info("All users have a valid username.")
} else {
logger.Warn("%d user(s) have a non-valid username.", invalidUserCount)
}
return nil
}
func init() {
Register(&Check{
Title: "Check if users has an valid email address",
Name: "check-user-email",
IsDefault: false,
Run: checkUserEmail,
Priority: 9,
})
Register(&Check{
Title: "Check if users have a valid username",
Name: "check-user-names",
IsDefault: false,
Run: checkUserName,
Priority: 9,
})
}
+59
View File
@@ -0,0 +1,59 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package doctor
import (
"context"
"os"
"path/filepath"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/util"
)
func checkOldArchives(ctx context.Context, logger log.Logger, autofix bool) error {
numRepos := 0
numReposUpdated := 0
err := iterateRepositories(ctx, func(repo *repo_model.Repository) error {
if repo.IsEmpty {
return nil
}
p := filepath.Join(repo.RepoPath(), "archives")
isDir, err := util.IsDir(p)
if err != nil {
log.Warn("check if %s is directory failed: %v", p, err)
}
if isDir {
numRepos++
if autofix {
if err := os.RemoveAll(p); err == nil {
numReposUpdated++
} else {
log.Warn("remove %s failed: %v", p, err)
}
}
}
return nil
})
if autofix {
logger.Info("%d / %d old archives in repository deleted", numReposUpdated, numRepos)
} else {
logger.Info("%d old archives in repository need to be deleted", numRepos)
}
return err
}
func init() {
Register(&Check{
Title: "Check old archives",
Name: "check-old-archives",
IsDefault: false,
Run: checkOldArchives,
Priority: 7,
})
}
+245
View File
@@ -0,0 +1,245 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package doctor
import (
"context"
actions_model "code.gitea.io/gitea/models/actions"
activities_model "code.gitea.io/gitea/models/activities"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
"code.gitea.io/gitea/models/migrations"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
)
type consistencyCheck struct {
Name string
Counter func(context.Context) (int64, error)
Fixer func(context.Context) (int64, error)
FixedMessage string
}
func (c *consistencyCheck) Run(ctx context.Context, logger log.Logger, autofix bool) error {
count, err := c.Counter(ctx)
if err != nil {
logger.Critical("Error: %v whilst counting %s", err, c.Name)
return err
}
if count > 0 {
if autofix {
var fixed int64
if fixed, err = c.Fixer(ctx); err != nil {
logger.Critical("Error: %v whilst fixing %s", err, c.Name)
return err
}
prompt := "Deleted"
if c.FixedMessage != "" {
prompt = c.FixedMessage
}
if fixed < 0 {
logger.Info(prompt+" %d %s", count, c.Name)
} else {
logger.Info(prompt+" %d/%d %s", fixed, count, c.Name)
}
} else {
logger.Warn("Found %d %s", count, c.Name)
}
}
return nil
}
func asFixer(fn func(ctx context.Context) error) func(ctx context.Context) (int64, error) {
return func(ctx context.Context) (int64, error) {
err := fn(ctx)
return -1, err
}
}
func genericOrphanCheck(name, subject, refobject, joincond string) consistencyCheck {
return consistencyCheck{
Name: name,
Counter: func(ctx context.Context) (int64, error) {
return db.CountOrphanedObjects(ctx, subject, refobject, joincond)
},
Fixer: func(ctx context.Context) (int64, error) {
err := db.DeleteOrphanedObjects(ctx, subject, refobject, joincond)
return -1, err
},
}
}
func checkDBConsistency(ctx context.Context, logger log.Logger, autofix bool) error {
// make sure DB version is uptodate
if err := db.InitEngineWithMigration(ctx, migrations.EnsureUpToDate); err != nil {
logger.Critical("Model version on the database does not match the current Gitea version. Model consistency will not be checked until the database is upgraded")
return err
}
consistencyChecks := []consistencyCheck{
{
// find labels without existing repo or org
Name: "Orphaned Labels without existing repository or organisation",
Counter: issues_model.CountOrphanedLabels,
Fixer: asFixer(issues_model.DeleteOrphanedLabels),
},
{
// find IssueLabels without existing label
Name: "Orphaned Issue Labels without existing label",
Counter: issues_model.CountOrphanedIssueLabels,
Fixer: asFixer(issues_model.DeleteOrphanedIssueLabels),
},
{
// find issues without existing repository
Name: "Orphaned Issues without existing repository",
Counter: issues_model.CountOrphanedIssues,
Fixer: asFixer(issues_model.DeleteOrphanedIssues),
},
// find releases without existing repository
genericOrphanCheck("Orphaned Releases without existing repository",
"release", "repository", "`release`.repo_id=repository.id"),
// find pulls without existing issues
genericOrphanCheck("Orphaned PullRequests without existing issue",
"pull_request", "issue", "pull_request.issue_id=issue.id"),
// find pull requests without base repository
genericOrphanCheck("Pull request entries without existing base repository",
"pull_request", "repository", "pull_request.base_repo_id=repository.id"),
// find tracked times without existing issues/pulls
genericOrphanCheck("Orphaned TrackedTimes without existing issue",
"tracked_time", "issue", "tracked_time.issue_id=issue.id"),
// find attachments without existing issues or releases
{
Name: "Orphaned Attachments without existing issues or releases",
Counter: repo_model.CountOrphanedAttachments,
Fixer: asFixer(repo_model.DeleteOrphanedAttachments),
},
// find null archived repositories
{
Name: "Repositories with is_archived IS NULL",
Counter: repo_model.CountNullArchivedRepository,
Fixer: repo_model.FixNullArchivedRepository,
FixedMessage: "Fixed",
},
// find label comments with empty labels
{
Name: "Label comments with empty labels",
Counter: issues_model.CountCommentTypeLabelWithEmptyLabel,
Fixer: issues_model.FixCommentTypeLabelWithEmptyLabel,
FixedMessage: "Fixed",
},
// find label comments with labels from outside the repository
{
Name: "Label comments with labels from outside the repository",
Counter: issues_model.CountCommentTypeLabelWithOutsideLabels,
Fixer: issues_model.FixCommentTypeLabelWithOutsideLabels,
FixedMessage: "Removed",
},
// find issue_label with labels from outside the repository
{
Name: "IssueLabels with Labels from outside the repository",
Counter: issues_model.CountIssueLabelWithOutsideLabels,
Fixer: issues_model.FixIssueLabelWithOutsideLabels,
FixedMessage: "Removed",
},
{
Name: "Action with created_unix set as an empty string",
Counter: activities_model.CountActionCreatedUnixString,
Fixer: activities_model.FixActionCreatedUnixString,
FixedMessage: "Set to zero",
},
{
Name: "Action Runners without existing owner",
Counter: actions_model.CountRunnersWithoutBelongingOwner,
Fixer: actions_model.FixRunnersWithoutBelongingOwner,
FixedMessage: "Removed",
},
{
Name: "Topics with empty repository count",
Counter: repo_model.CountOrphanedTopics,
Fixer: repo_model.DeleteOrphanedTopics,
FixedMessage: "Removed",
},
}
// TODO: function to recalc all counters
if setting.Database.Type.IsPostgreSQL() {
consistencyChecks = append(consistencyChecks, consistencyCheck{
Name: "Sequence values",
Counter: db.CountBadSequences,
Fixer: asFixer(db.FixBadSequences),
FixedMessage: "Updated",
})
}
consistencyChecks = append(consistencyChecks,
// find protected branches without existing repository
genericOrphanCheck("Protected Branches without existing repository",
"protected_branch", "repository", "protected_branch.repo_id=repository.id"),
// find branches without existing repository
genericOrphanCheck("Branches without existing repository",
"branch", "repository", "branch.repo_id=repository.id"),
// find LFS locks without existing repository
genericOrphanCheck("LFS locks without existing repository",
"lfs_lock", "repository", "lfs_lock.repo_id=repository.id"),
// find collaborations without users
genericOrphanCheck("Collaborations without existing user",
"collaboration", "user", "collaboration.user_id=`user`.id"),
// find collaborations without repository
genericOrphanCheck("Collaborations without existing repository",
"collaboration", "repository", "collaboration.repo_id=repository.id"),
// find access without users
genericOrphanCheck("Access entries without existing user",
"access", "user", "access.user_id=`user`.id"),
// find access without repository
genericOrphanCheck("Access entries without existing repository",
"access", "repository", "access.repo_id=repository.id"),
// find action without repository
genericOrphanCheck("Action entries without existing repository",
"action", "repository", "action.repo_id=repository.id"),
// find action without user
genericOrphanCheck("Action entries without existing user",
"action", "user", "action.act_user_id=`user`.id"),
// find OAuth2Grant without existing user
genericOrphanCheck("Orphaned OAuth2Grant without existing User",
"oauth2_grant", "user", "oauth2_grant.user_id=`user`.id"),
// find OAuth2Application without existing user
genericOrphanCheck("Orphaned OAuth2Application without existing User",
"oauth2_application", "user", "oauth2_application.uid=`user`.id"),
// find OAuth2AuthorizationCode without existing OAuth2Grant
genericOrphanCheck("Orphaned OAuth2AuthorizationCode without existing OAuth2Grant",
"oauth2_authorization_code", "oauth2_grant", "oauth2_authorization_code.grant_id=oauth2_grant.id"),
// find stopwatches without existing user
genericOrphanCheck("Orphaned Stopwatches without existing User",
"stopwatch", "user", "stopwatch.user_id=`user`.id"),
// find stopwatches without existing issue
genericOrphanCheck("Orphaned Stopwatches without existing Issue",
"stopwatch", "issue", "stopwatch.issue_id=`issue`.id"),
// find redirects without existing user.
genericOrphanCheck("Orphaned Redirects without existing redirect user",
"user_redirect", "user", "user_redirect.redirect_user_id=`user`.id"),
)
for _, c := range consistencyChecks {
if err := c.Run(ctx, logger, autofix); err != nil {
return err
}
}
return nil
}
func init() {
Register(&Check{
Title: "Check consistency of database",
Name: "check-db-consistency",
IsDefault: false,
Run: checkDBConsistency,
Priority: 3,
})
}
+42
View File
@@ -0,0 +1,42 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package doctor
import (
"context"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/migrations"
"code.gitea.io/gitea/modules/log"
)
func checkDBVersion(ctx context.Context, logger log.Logger, autofix bool) error {
logger.Info("Expected database version: %d", migrations.ExpectedVersion())
if err := db.InitEngineWithMigration(ctx, migrations.EnsureUpToDate); err != nil {
if !autofix {
logger.Critical("Error: %v during ensure up to date", err)
return err
}
logger.Warn("Got Error: %v during ensure up to date", err)
logger.Warn("Attempting to migrate to the latest DB version to fix this.")
err = db.InitEngineWithMigration(ctx, migrations.Migrate)
if err != nil {
logger.Critical("Error: %v during migration", err)
}
return err
}
return nil
}
func init() {
Register(&Check{
Title: "Check Database Version",
Name: "check-db-version",
IsDefault: true,
Run: checkDBVersion,
AbortIfFailed: false,
Priority: 2,
})
}
+127
View File
@@ -0,0 +1,127 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package doctor
import (
"context"
"fmt"
"os"
"sort"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
)
// Check represents a Doctor check
type Check struct {
Title string
Name string
IsDefault bool
Run func(ctx context.Context, logger log.Logger, autofix bool) error
AbortIfFailed bool
SkipDatabaseInitialization bool
Priority int
}
func initDBSkipLogger(ctx context.Context) error {
setting.MustInstalled()
setting.LoadDBSetting()
if err := db.InitEngine(ctx); err != nil {
return fmt.Errorf("db.InitEngine: %w", err)
}
// some doctor sub-commands need to use git command
if err := git.InitFull(ctx); err != nil {
return fmt.Errorf("git.InitFull: %w", err)
}
return nil
}
type doctorCheckLogger struct {
colorize bool
}
var _ log.BaseLogger = (*doctorCheckLogger)(nil)
func (d *doctorCheckLogger) Log(skip int, level log.Level, format string, v ...any) {
_, _ = fmt.Fprintf(os.Stdout, format+"\n", v...)
}
func (d *doctorCheckLogger) GetLevel() log.Level {
return log.TRACE
}
type doctorCheckStepLogger struct {
colorize bool
}
var _ log.BaseLogger = (*doctorCheckStepLogger)(nil)
func (d *doctorCheckStepLogger) Log(skip int, level log.Level, format string, v ...any) {
levelChar := fmt.Sprintf("[%s]", strings.ToUpper(level.String()[0:1]))
var levelArg any = levelChar
if d.colorize {
levelArg = log.NewColoredValue(levelChar, level.ColorAttributes()...)
}
args := append([]any{levelArg}, v...)
_, _ = fmt.Fprintf(os.Stdout, " - %s "+format+"\n", args...)
}
func (d *doctorCheckStepLogger) GetLevel() log.Level {
return log.TRACE
}
// Checks is the list of available commands
var Checks []*Check
// RunChecks runs the doctor checks for the provided list
func RunChecks(ctx context.Context, colorize, autofix bool, checks []*Check) error {
SortChecks(checks)
// the checks output logs by a special logger, they do not use the default logger
logger := log.BaseLoggerToGeneralLogger(&doctorCheckLogger{colorize: colorize})
loggerStep := log.BaseLoggerToGeneralLogger(&doctorCheckStepLogger{colorize: colorize})
dbIsInit := false
for i, check := range checks {
if !dbIsInit && !check.SkipDatabaseInitialization {
// Only open database after the most basic configuration check
if err := initDBSkipLogger(ctx); err != nil {
logger.Error("Error whilst initializing the database: %v", err)
logger.Error("Check if you are using the right config file. You can use a --config directive to specify one.")
return nil
}
dbIsInit = true
}
logger.Info("\n[%d] %s", i+1, check.Title)
if err := check.Run(ctx, loggerStep, autofix); err != nil {
if check.AbortIfFailed {
logger.Critical("FAIL")
return err
}
logger.Error("ERROR")
} else {
logger.Info("OK")
}
}
logger.Info("\nAll done (checks: %d).", len(checks))
return nil
}
// Register registers a command with the list
func Register(command *Check) {
Checks = append(Checks, command)
}
func SortChecks(checks []*Check) {
sort.SliceStable(checks, func(i, j int) bool {
if checks[i].Priority == checks[j].Priority {
return checks[i].Name < checks[j].Name
}
if checks[i].Priority == 0 {
return false
}
return checks[i].Priority < checks[j].Priority
})
}
+328
View File
@@ -0,0 +1,328 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package doctor
import (
"bytes"
"context"
"errors"
"fmt"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unit"
"code.gitea.io/gitea/modules/json"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/builder"
)
// #16831 revealed that the dump command that was broken in 1.14.3-1.14.6 and 1.15.0 (#15885).
// This led to repo_unit and login_source cfg not being converted to JSON in the dump
// Unfortunately although it was hoped that there were only a few users affected it
// appears that many users are affected.
// We therefore need to provide a doctor command to fix this repeated issue #16961
func parseBool16961(bs []byte) (bool, error) {
if bytes.EqualFold(bs, []byte("%!s(bool=false)")) {
return false, nil
}
if bytes.EqualFold(bs, []byte("%!s(bool=true)")) {
return true, nil
}
return false, fmt.Errorf("unexpected bool format: %s", string(bs))
}
func fixUnitConfig16961(bs []byte, cfg *repo_model.UnitConfig) (fixed bool, err error) {
err = json.UnmarshalHandleDoubleEncode(bs, &cfg)
if err == nil {
return false, nil
}
// Handle #16961
if string(bs) != "&{}" && len(bs) != 0 {
return false, err
}
return true, nil
}
func fixExternalWikiConfig16961(bs []byte, cfg *repo_model.ExternalWikiConfig) (fixed bool, err error) {
err = json.UnmarshalHandleDoubleEncode(bs, &cfg)
if err == nil {
return false, nil
}
if len(bs) < 3 {
return false, err
}
if bs[0] != '&' || bs[1] != '{' || bs[len(bs)-1] != '}' {
return false, err
}
cfg.ExternalWikiURL = string(bs[2 : len(bs)-1])
return true, nil
}
func fixExternalTrackerConfig16961(bs []byte, cfg *repo_model.ExternalTrackerConfig) (fixed bool, err error) {
err = json.UnmarshalHandleDoubleEncode(bs, &cfg)
if err == nil {
return false, nil
}
// Handle #16961
if len(bs) < 3 {
return false, err
}
if bs[0] != '&' || bs[1] != '{' || bs[len(bs)-1] != '}' {
return false, err
}
parts := bytes.Split(bs[2:len(bs)-1], []byte{' '})
if len(parts) != 3 {
return false, err
}
cfg.ExternalTrackerURL = string(bytes.Join(parts[:len(parts)-2], []byte{' '}))
cfg.ExternalTrackerFormat = string(parts[len(parts)-2])
cfg.ExternalTrackerStyle = string(parts[len(parts)-1])
return true, nil
}
func fixPullRequestsConfig16961(bs []byte, cfg *repo_model.PullRequestsConfig) (fixed bool, err error) {
err = json.UnmarshalHandleDoubleEncode(bs, &cfg)
if err == nil {
return false, nil
}
// Handle #16961
if len(bs) < 3 {
return false, err
}
if bs[0] != '&' || bs[1] != '{' || bs[len(bs)-1] != '}' {
return false, err
}
// PullRequestsConfig was the following in 1.14
// type PullRequestsConfig struct {
// IgnoreWhitespaceConflicts bool
// AllowMerge bool
// AllowRebase bool
// AllowRebaseMerge bool
// AllowSquash bool
// AllowManualMerge bool
// AutodetectManualMerge bool
// }
//
// 1.15 added in addition:
// DefaultDeleteBranchAfterMerge bool
// DefaultMergeStyle MergeStyle
parts := bytes.Split(bs[2:len(bs)-1], []byte{' '})
if len(parts) < 7 {
return false, err
}
var parseErr error
cfg.IgnoreWhitespaceConflicts, parseErr = parseBool16961(parts[0])
if parseErr != nil {
return false, errors.Join(err, parseErr)
}
cfg.AllowMerge, parseErr = parseBool16961(parts[1])
if parseErr != nil {
return false, errors.Join(err, parseErr)
}
cfg.AllowRebase, parseErr = parseBool16961(parts[2])
if parseErr != nil {
return false, errors.Join(err, parseErr)
}
cfg.AllowRebaseMerge, parseErr = parseBool16961(parts[3])
if parseErr != nil {
return false, errors.Join(err, parseErr)
}
cfg.AllowSquash, parseErr = parseBool16961(parts[4])
if parseErr != nil {
return false, errors.Join(err, parseErr)
}
cfg.AllowManualMerge, parseErr = parseBool16961(parts[5])
if parseErr != nil {
return false, errors.Join(err, parseErr)
}
cfg.AutodetectManualMerge, parseErr = parseBool16961(parts[6])
if parseErr != nil {
return false, errors.Join(err, parseErr)
}
// 1.14 unit
if len(parts) == 7 {
return true, nil
}
if len(parts) < 9 {
return false, err
}
cfg.DefaultDeleteBranchAfterMerge, parseErr = parseBool16961(parts[7])
if parseErr != nil {
return false, errors.Join(err, parseErr)
}
cfg.DefaultMergeStyle = repo_model.MergeStyle(string(bytes.Join(parts[8:], []byte{' '})))
return true, nil
}
func fixIssuesConfig16961(bs []byte, cfg *repo_model.IssuesConfig) (fixed bool, err error) {
err = json.UnmarshalHandleDoubleEncode(bs, &cfg)
if err == nil {
return false, nil
}
// Handle #16961
if len(bs) < 3 {
return false, err
}
if bs[0] != '&' || bs[1] != '{' || bs[len(bs)-1] != '}' {
return false, err
}
parts := bytes.Split(bs[2:len(bs)-1], []byte{' '})
if len(parts) != 3 {
return false, err
}
var parseErr error
cfg.EnableTimetracker, parseErr = parseBool16961(parts[0])
if parseErr != nil {
return false, errors.Join(err, parseErr)
}
cfg.AllowOnlyContributorsToTrackTime, parseErr = parseBool16961(parts[1])
if parseErr != nil {
return false, errors.Join(err, parseErr)
}
cfg.EnableDependencies, parseErr = parseBool16961(parts[2])
if parseErr != nil {
return false, errors.Join(err, parseErr)
}
return true, nil
}
func fixBrokenRepoUnit16961(repoUnit *repo_model.RepoUnit, bs []byte) (fixed bool, err error) {
// Shortcut empty or null values
if len(bs) == 0 {
return false, nil
}
var cfg any
err = json.UnmarshalHandleDoubleEncode(bs, &cfg)
if err == nil {
return false, nil
}
switch repoUnit.Type {
case unit.TypeCode, unit.TypeReleases, unit.TypeWiki, unit.TypeProjects:
cfg := &repo_model.UnitConfig{}
repoUnit.Config = cfg
if fixed, err := fixUnitConfig16961(bs, cfg); !fixed {
return false, err
}
case unit.TypeExternalWiki:
cfg := &repo_model.ExternalWikiConfig{}
repoUnit.Config = cfg
if fixed, err := fixExternalWikiConfig16961(bs, cfg); !fixed {
return false, err
}
case unit.TypeExternalTracker:
cfg := &repo_model.ExternalTrackerConfig{}
repoUnit.Config = cfg
if fixed, err := fixExternalTrackerConfig16961(bs, cfg); !fixed {
return false, err
}
case unit.TypePullRequests:
cfg := &repo_model.PullRequestsConfig{}
repoUnit.Config = cfg
if fixed, err := fixPullRequestsConfig16961(bs, cfg); !fixed {
return false, err
}
case unit.TypeIssues:
cfg := &repo_model.IssuesConfig{}
repoUnit.Config = cfg
if fixed, err := fixIssuesConfig16961(bs, cfg); !fixed {
return false, err
}
default:
panic(fmt.Sprintf("unrecognized repo unit type: %v", repoUnit.Type))
}
return true, nil
}
func fixBrokenRepoUnits16961(ctx context.Context, logger log.Logger, autofix bool) error {
// RepoUnit describes all units of a repository
type RepoUnit struct {
ID int64
RepoID int64
Type unit.Type
Config []byte
CreatedUnix timeutil.TimeStamp `xorm:"INDEX CREATED"`
}
count := 0
err := db.Iterate(
ctx,
builder.Gt{
"id": 0,
},
func(ctx context.Context, unit *RepoUnit) error {
bs := unit.Config
repoUnit := &repo_model.RepoUnit{
ID: unit.ID,
RepoID: unit.RepoID,
Type: unit.Type,
CreatedUnix: unit.CreatedUnix,
}
if fixed, err := fixBrokenRepoUnit16961(repoUnit, bs); !fixed {
return err
}
count++
if !autofix {
return nil
}
return repo_model.UpdateRepoUnit(ctx, repoUnit)
},
)
if err != nil {
logger.Critical("Unable to iterate across repounits to fix the broken units: Error %v", err)
return err
}
if !autofix {
if count == 0 {
logger.Info("Found no broken repo_units")
} else {
logger.Warn("Found %d broken repo_units", count)
}
return nil
}
logger.Info("Fixed %d broken repo_units", count)
return nil
}
func init() {
Register(&Check{
Title: "Check for incorrectly dumped repo_units (See #16961)",
Name: "fix-broken-repo-units",
IsDefault: false,
Run: fixBrokenRepoUnits16961,
Priority: 7,
})
}
+271
View File
@@ -0,0 +1,271 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package doctor
import (
"testing"
repo_model "code.gitea.io/gitea/models/repo"
"github.com/stretchr/testify/assert"
)
func Test_fixUnitConfig_16961(t *testing.T) {
tests := []struct {
name string
bs string
wantFixed bool
wantErr bool
}{
{
name: "empty",
bs: "",
wantFixed: true,
wantErr: false,
},
{
name: "normal: {}",
bs: "{}",
wantFixed: false,
wantErr: false,
},
{
name: "broken but fixable: &{}",
bs: "&{}",
wantFixed: true,
wantErr: false,
},
{
name: "broken but unfixable: &{asdasd}",
bs: "&{asdasd}",
wantFixed: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotFixed, err := fixUnitConfig16961([]byte(tt.bs), &repo_model.UnitConfig{})
if (err != nil) != tt.wantErr {
t.Errorf("fixUnitConfig_16961() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotFixed != tt.wantFixed {
t.Errorf("fixUnitConfig_16961() = %v, want %v", gotFixed, tt.wantFixed)
}
})
}
}
func Test_fixExternalWikiConfig_16961(t *testing.T) {
tests := []struct {
name string
bs string
expected string
wantFixed bool
wantErr bool
}{
{
name: "normal: {\"ExternalWikiURL\":\"http://someurl\"}",
bs: "{\"ExternalWikiURL\":\"http://someurl\"}",
expected: "http://someurl",
wantFixed: false,
wantErr: false,
},
{
name: "broken: &{http://someurl}",
bs: "&{http://someurl}",
expected: "http://someurl",
wantFixed: true,
wantErr: false,
},
{
name: "broken but unfixable: http://someurl",
bs: "http://someurl",
wantFixed: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &repo_model.ExternalWikiConfig{}
gotFixed, err := fixExternalWikiConfig16961([]byte(tt.bs), cfg)
if (err != nil) != tt.wantErr {
t.Errorf("fixExternalWikiConfig_16961() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotFixed != tt.wantFixed {
t.Errorf("fixExternalWikiConfig_16961() = %v, want %v", gotFixed, tt.wantFixed)
}
if cfg.ExternalWikiURL != tt.expected {
t.Errorf("fixExternalWikiConfig_16961().ExternalWikiURL = %v, want %v", cfg.ExternalWikiURL, tt.expected)
}
})
}
}
func Test_fixExternalTrackerConfig_16961(t *testing.T) {
tests := []struct {
name string
bs string
expected repo_model.ExternalTrackerConfig
wantFixed bool
wantErr bool
}{
{
name: "normal",
bs: `{"ExternalTrackerURL":"a","ExternalTrackerFormat":"b","ExternalTrackerStyle":"c"}`,
expected: repo_model.ExternalTrackerConfig{
ExternalTrackerURL: "a",
ExternalTrackerFormat: "b",
ExternalTrackerStyle: "c",
},
wantFixed: false,
wantErr: false,
},
{
name: "broken",
bs: "&{a b c}",
expected: repo_model.ExternalTrackerConfig{
ExternalTrackerURL: "a",
ExternalTrackerFormat: "b",
ExternalTrackerStyle: "c",
},
wantFixed: true,
wantErr: false,
},
{
name: "broken - too many fields",
bs: "&{a b c d}",
wantFixed: false,
wantErr: true,
},
{
name: "broken - wrong format",
bs: "a b c d}",
wantFixed: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &repo_model.ExternalTrackerConfig{}
gotFixed, err := fixExternalTrackerConfig16961([]byte(tt.bs), cfg)
if (err != nil) != tt.wantErr {
t.Errorf("fixExternalTrackerConfig_16961() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotFixed != tt.wantFixed {
t.Errorf("fixExternalTrackerConfig_16961() = %v, want %v", gotFixed, tt.wantFixed)
}
if cfg.ExternalTrackerFormat != tt.expected.ExternalTrackerFormat {
t.Errorf("fixExternalTrackerConfig_16961().ExternalTrackerFormat = %v, want %v", tt.expected.ExternalTrackerFormat, cfg.ExternalTrackerFormat)
}
if cfg.ExternalTrackerStyle != tt.expected.ExternalTrackerStyle {
t.Errorf("fixExternalTrackerConfig_16961().ExternalTrackerStyle = %v, want %v", tt.expected.ExternalTrackerStyle, cfg.ExternalTrackerStyle)
}
if cfg.ExternalTrackerURL != tt.expected.ExternalTrackerURL {
t.Errorf("fixExternalTrackerConfig_16961().ExternalTrackerURL = %v, want %v", tt.expected.ExternalTrackerURL, cfg.ExternalTrackerURL)
}
})
}
}
func Test_fixPullRequestsConfig_16961(t *testing.T) {
tests := []struct {
name string
bs string
expected repo_model.PullRequestsConfig
wantFixed bool
wantErr bool
}{
{
name: "normal",
bs: `{"IgnoreWhitespaceConflicts":false,"AllowMerge":false,"AllowRebase":false,"AllowRebaseMerge":false,"AllowSquash":false,"AllowManualMerge":false,"AutodetectManualMerge":false,"DefaultDeleteBranchAfterMerge":false,"DefaultMergeStyle":""}`,
},
{
name: "broken - 1.14",
bs: `&{%!s(bool=false) %!s(bool=true) %!s(bool=true) %!s(bool=true) %!s(bool=true) %!s(bool=false) %!s(bool=false)}`,
expected: repo_model.PullRequestsConfig{
IgnoreWhitespaceConflicts: false,
AllowMerge: true,
AllowRebase: true,
AllowRebaseMerge: true,
AllowSquash: true,
AllowManualMerge: false,
AutodetectManualMerge: false,
},
wantFixed: true,
},
{
name: "broken - 1.15",
bs: `&{%!s(bool=false) %!s(bool=true) %!s(bool=true) %!s(bool=true) %!s(bool=true) %!s(bool=false) %!s(bool=false) %!s(bool=false) merge}`,
expected: repo_model.PullRequestsConfig{
AllowMerge: true,
AllowRebase: true,
AllowRebaseMerge: true,
AllowSquash: true,
DefaultMergeStyle: repo_model.MergeStyleMerge,
},
wantFixed: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &repo_model.PullRequestsConfig{}
gotFixed, err := fixPullRequestsConfig16961([]byte(tt.bs), cfg)
if (err != nil) != tt.wantErr {
t.Errorf("fixPullRequestsConfig_16961() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotFixed != tt.wantFixed {
t.Errorf("fixPullRequestsConfig_16961() = %v, want %v", gotFixed, tt.wantFixed)
}
assert.EqualValues(t, &tt.expected, cfg)
})
}
}
func Test_fixIssuesConfig_16961(t *testing.T) {
tests := []struct {
name string
bs string
expected repo_model.IssuesConfig
wantFixed bool
wantErr bool
}{
{
name: "normal",
bs: `{"EnableTimetracker":true,"AllowOnlyContributorsToTrackTime":true,"EnableDependencies":true}`,
expected: repo_model.IssuesConfig{
EnableTimetracker: true,
AllowOnlyContributorsToTrackTime: true,
EnableDependencies: true,
},
},
{
name: "broken",
bs: `&{%!s(bool=true) %!s(bool=true) %!s(bool=true)}`,
expected: repo_model.IssuesConfig{
EnableTimetracker: true,
AllowOnlyContributorsToTrackTime: true,
EnableDependencies: true,
},
wantFixed: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &repo_model.IssuesConfig{}
gotFixed, err := fixIssuesConfig16961([]byte(tt.bs), cfg)
if (err != nil) != tt.wantErr {
t.Errorf("fixIssuesConfig_16961() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotFixed != tt.wantFixed {
t.Errorf("fixIssuesConfig_16961() = %v, want %v", gotFixed, tt.wantFixed)
}
assert.EqualValues(t, &tt.expected, cfg)
})
}
}
+61
View File
@@ -0,0 +1,61 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package doctor
import (
"context"
"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/db"
org_model "code.gitea.io/gitea/models/organization"
"code.gitea.io/gitea/models/perm"
"code.gitea.io/gitea/modules/log"
"xorm.io/builder"
)
func fixOwnerTeamCreateOrgRepo(ctx context.Context, logger log.Logger, autofix bool) error {
count := 0
err := db.Iterate(
ctx,
builder.Eq{"authorize": perm.AccessModeOwner, "can_create_org_repo": false},
func(ctx context.Context, team *org_model.Team) error {
team.CanCreateOrgRepo = true
count++
if !autofix {
return nil
}
return models.UpdateTeam(ctx, team, false, false)
},
)
if err != nil {
logger.Critical("Unable to iterate across repounits to fix incorrect can_create_org_repo: Error %v", err)
return err
}
if !autofix {
if count == 0 {
logger.Info("Found no team with incorrect can_create_org_repo")
} else {
logger.Warn("Found %d teams with incorrect can_create_org_repo", count)
}
return nil
}
logger.Info("Fixed %d teams with incorrect can_create_org_repo", count)
return nil
}
func init() {
Register(&Check{
Title: "Check for incorrect can_create_org_repo for org owner teams",
Name: "fix-owner-team-create-org-repo",
IsDefault: false,
Run: fixOwnerTeamCreateOrgRepo,
Priority: 7,
})
}
+88
View File
@@ -0,0 +1,88 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package doctor
import (
"context"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/log"
)
func synchronizeRepoHeads(ctx context.Context, logger log.Logger, autofix bool) error {
numRepos := 0
numHeadsBroken := 0
numDefaultBranchesBroken := 0
numReposUpdated := 0
err := iterateRepositories(ctx, func(repo *repo_model.Repository) error {
numRepos++
_, _, defaultBranchErr := git.NewCommand(ctx, "rev-parse").AddDashesAndList(repo.DefaultBranch).RunStdString(&git.RunOpts{Dir: repo.RepoPath()})
head, _, headErr := git.NewCommand(ctx, "symbolic-ref", "--short", "HEAD").RunStdString(&git.RunOpts{Dir: repo.RepoPath()})
// what we expect: default branch is valid, and HEAD points to it
if headErr == nil && defaultBranchErr == nil && head == repo.DefaultBranch {
return nil
}
if headErr != nil {
numHeadsBroken++
}
if defaultBranchErr != nil {
numDefaultBranchesBroken++
}
// if default branch is broken, let the user fix that in the UI
if defaultBranchErr != nil {
logger.Warn("Default branch for %s/%s doesn't point to a valid commit", repo.OwnerName, repo.Name)
return nil
}
// if we're not autofixing, that's all we can do
if !autofix {
return nil
}
// otherwise, let's try fixing HEAD
err := git.NewCommand(ctx, "symbolic-ref").AddDashesAndList("HEAD", git.BranchPrefix+repo.DefaultBranch).Run(&git.RunOpts{Dir: repo.RepoPath()})
if err != nil {
logger.Warn("Failed to fix HEAD for %s/%s: %v", repo.OwnerName, repo.Name, err)
return nil
}
numReposUpdated++
return nil
})
if err != nil {
logger.Critical("Error when fixing repo HEADs: %v", err)
}
if autofix {
logger.Info("Out of %d repos, HEADs for %d are now fixed and HEADS for %d are still broken", numRepos, numReposUpdated, numDefaultBranchesBroken+numHeadsBroken-numReposUpdated)
} else {
if numHeadsBroken == 0 && numDefaultBranchesBroken == 0 {
logger.Info("All %d repos have their HEADs in the correct state", numRepos)
} else {
if numHeadsBroken == 0 && numDefaultBranchesBroken != 0 {
logger.Critical("Default branches are broken for %d/%d repos", numDefaultBranchesBroken, numRepos)
} else if numHeadsBroken != 0 && numDefaultBranchesBroken == 0 {
logger.Warn("HEADs are broken for %d/%d repos", numHeadsBroken, numRepos)
} else {
logger.Critical("Out of %d repos, HEADS are broken for %d and default branches are broken for %d", numRepos, numHeadsBroken, numDefaultBranchesBroken)
}
}
}
return err
}
func init() {
Register(&Check{
Title: "Synchronize repo HEADs",
Name: "synchronize-repo-heads",
IsDefault: true,
Run: synchronizeRepoHeads,
Priority: 7,
})
}
+51
View File
@@ -0,0 +1,51 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package doctor
import (
"context"
"fmt"
"time"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/services/repository"
)
func init() {
Register(&Check{
Title: "Garbage collect LFS",
Name: "gc-lfs",
IsDefault: false,
Run: garbageCollectLFSCheck,
AbortIfFailed: false,
SkipDatabaseInitialization: false,
Priority: 1,
})
}
func garbageCollectLFSCheck(ctx context.Context, logger log.Logger, autofix bool) error {
if !setting.LFS.StartServer {
return fmt.Errorf("LFS support is disabled")
}
if err := repository.GarbageCollectLFSMetaObjects(ctx, repository.GarbageCollectLFSMetaObjectsOptions{
LogDetail: logger.Info,
AutoFix: autofix,
// Only attempt to garbage collect lfs meta objects older than a week as the order of git lfs upload
// and git object upload is not necessarily guaranteed. It's possible to imagine a situation whereby
// an LFS object is uploaded but the git branch is not uploaded immediately, or there are some rapid
// changes in new branches that might lead to lfs objects becoming temporarily unassociated with git
// objects.
//
// It is likely that a week is potentially excessive but it should definitely be enough that any
// unassociated LFS object is genuinely unassociated.
OlderThan: time.Now().Add(-24 * time.Hour * 7),
// We don't set the UpdatedLessRecentlyThan because we want to do a full GC
}); err != nil {
return err
}
return checkStorage(&checkStorageOptions{LFS: true})(ctx, logger, autofix)
}
+114
View File
@@ -0,0 +1,114 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package doctor
import (
"context"
"fmt"
"strings"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/log"
"xorm.io/builder"
)
func iteratePRs(ctx context.Context, repo *repo_model.Repository, each func(*repo_model.Repository, *issues_model.PullRequest) error) error {
return db.Iterate(
ctx,
builder.Eq{"base_repo_id": repo.ID},
func(ctx context.Context, bean *issues_model.PullRequest) error {
return each(repo, bean)
},
)
}
func checkPRMergeBase(ctx context.Context, logger log.Logger, autofix bool) error {
numRepos := 0
numPRs := 0
numPRsUpdated := 0
err := iterateRepositories(ctx, func(repo *repo_model.Repository) error {
numRepos++
return iteratePRs(ctx, repo, func(repo *repo_model.Repository, pr *issues_model.PullRequest) error {
numPRs++
pr.BaseRepo = repo
repoPath := repo.RepoPath()
oldMergeBase := pr.MergeBase
if !pr.HasMerged {
var err error
pr.MergeBase, _, err = git.NewCommand(ctx, "merge-base").AddDashesAndList(pr.BaseBranch, pr.GetGitRefName()).RunStdString(&git.RunOpts{Dir: repoPath})
if err != nil {
var err2 error
pr.MergeBase, _, err2 = git.NewCommand(ctx, "rev-parse").AddDynamicArguments(git.BranchPrefix + pr.BaseBranch).RunStdString(&git.RunOpts{Dir: repoPath})
if err2 != nil {
logger.Warn("Unable to get merge base for PR ID %d, #%d onto %s in %s/%s. Error: %v & %v", pr.ID, pr.Index, pr.BaseBranch, pr.BaseRepo.OwnerName, pr.BaseRepo.Name, err, err2)
return nil
}
}
} else {
parentsString, _, err := git.NewCommand(ctx, "rev-list", "--parents", "-n", "1").AddDynamicArguments(pr.MergedCommitID).RunStdString(&git.RunOpts{Dir: repoPath})
if err != nil {
logger.Warn("Unable to get parents for merged PR ID %d, #%d onto %s in %s/%s. Error: %v", pr.ID, pr.Index, pr.BaseBranch, pr.BaseRepo.OwnerName, pr.BaseRepo.Name, err)
return nil
}
parents := strings.Split(strings.TrimSpace(parentsString), " ")
if len(parents) < 2 {
return nil
}
refs := append([]string{}, parents[1:]...)
refs = append(refs, pr.GetGitRefName())
cmd := git.NewCommand(ctx, "merge-base").AddDashesAndList(refs...)
pr.MergeBase, _, err = cmd.RunStdString(&git.RunOpts{Dir: repoPath})
if err != nil {
logger.Warn("Unable to get merge base for merged PR ID %d, #%d onto %s in %s/%s. Error: %v", pr.ID, pr.Index, pr.BaseBranch, pr.BaseRepo.OwnerName, pr.BaseRepo.Name, err)
return nil
}
}
pr.MergeBase = strings.TrimSpace(pr.MergeBase)
if pr.MergeBase != oldMergeBase {
if autofix {
if err := pr.UpdateCols(ctx, "merge_base"); err != nil {
logger.Critical("Failed to update merge_base. ERROR: %v", err)
return fmt.Errorf("Failed to update merge_base. ERROR: %w", err)
}
} else {
logger.Info("#%d onto %s in %s/%s: MergeBase should be %s but is %s", pr.Index, pr.BaseBranch, pr.BaseRepo.OwnerName, pr.BaseRepo.Name, oldMergeBase, pr.MergeBase)
}
numPRsUpdated++
}
return nil
})
})
if autofix {
logger.Info("%d PR mergebases updated of %d PRs total in %d repos", numPRsUpdated, numPRs, numRepos)
} else {
if numPRsUpdated == 0 {
logger.Info("All %d PRs in %d repos have a correct mergebase", numPRs, numRepos)
} else if err == nil {
logger.Critical("%d PRs with incorrect mergebases of %d PRs total in %d repos", numPRsUpdated, numPRs, numRepos)
return fmt.Errorf("%d PRs with incorrect mergebases of %d PRs total in %d repos", numPRsUpdated, numPRs, numRepos)
} else {
logger.Warn("%d PRs with incorrect mergebases of %d PRs total in %d repos", numPRsUpdated, numPRs, numRepos)
}
}
return err
}
func init() {
Register(&Check{
Title: "Recalculate merge bases",
Name: "recalculate-merge-bases",
IsDefault: false,
Run: checkPRMergeBase,
Priority: 7,
})
}

Some files were not shown because too many files have changed in this diff Show More