mirror of
				https://github.com/go-gitea/gitea
				synced 2025-10-26 17:08:25 +00:00 
			
		
		
		
	| @@ -41,15 +41,15 @@ func init() { | ||||
| } | ||||
|  | ||||
| // GetSchedulesMapByIDs returns the schedules by given id slice. | ||||
| func GetSchedulesMapByIDs(ids []int64) (map[int64]*ActionSchedule, error) { | ||||
| func GetSchedulesMapByIDs(ctx context.Context, ids []int64) (map[int64]*ActionSchedule, error) { | ||||
| 	schedules := make(map[int64]*ActionSchedule, len(ids)) | ||||
| 	return schedules, db.GetEngine(db.DefaultContext).In("id", ids).Find(&schedules) | ||||
| 	return schedules, db.GetEngine(ctx).In("id", ids).Find(&schedules) | ||||
| } | ||||
|  | ||||
| // GetReposMapByIDs returns the repos by given id slice. | ||||
| func GetReposMapByIDs(ids []int64) (map[int64]*repo_model.Repository, error) { | ||||
| func GetReposMapByIDs(ctx context.Context, ids []int64) (map[int64]*repo_model.Repository, error) { | ||||
| 	repos := make(map[int64]*repo_model.Repository, len(ids)) | ||||
| 	return repos, db.GetEngine(db.DefaultContext).In("id", ids).Find(&repos) | ||||
| 	return repos, db.GetEngine(ctx).In("id", ids).Find(&repos) | ||||
| } | ||||
|  | ||||
| var cronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor) | ||||
|   | ||||
| @@ -23,9 +23,9 @@ func (specs SpecList) GetScheduleIDs() []int64 { | ||||
| 	return ids.Values() | ||||
| } | ||||
|  | ||||
| func (specs SpecList) LoadSchedules() error { | ||||
| func (specs SpecList) LoadSchedules(ctx context.Context) error { | ||||
| 	scheduleIDs := specs.GetScheduleIDs() | ||||
| 	schedules, err := GetSchedulesMapByIDs(scheduleIDs) | ||||
| 	schedules, err := GetSchedulesMapByIDs(ctx, scheduleIDs) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -34,7 +34,7 @@ func (specs SpecList) LoadSchedules() error { | ||||
| 	} | ||||
|  | ||||
| 	repoIDs := specs.GetRepoIDs() | ||||
| 	repos, err := GetReposMapByIDs(repoIDs) | ||||
| 	repos, err := GetReposMapByIDs(ctx, repoIDs) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -95,7 +95,7 @@ func FindSpecs(ctx context.Context, opts FindSpecOptions) (SpecList, int64, erro | ||||
| 		return nil, 0, err | ||||
| 	} | ||||
|  | ||||
| 	if err := specs.LoadSchedules(); err != nil { | ||||
| 	if err := specs.LoadSchedules(ctx); err != nil { | ||||
| 		return nil, 0, err | ||||
| 	} | ||||
| 	return specs, total, nil | ||||
|   | ||||
| @@ -48,11 +48,7 @@ type TranslatableMessage struct { | ||||
| } | ||||
|  | ||||
| // LoadRepo loads repository of the task | ||||
| func (task *Task) LoadRepo() error { | ||||
| 	return task.loadRepo(db.DefaultContext) | ||||
| } | ||||
|  | ||||
| func (task *Task) loadRepo(ctx context.Context) error { | ||||
| func (task *Task) LoadRepo(ctx context.Context) error { | ||||
| 	if task.Repo != nil { | ||||
| 		return nil | ||||
| 	} | ||||
| @@ -70,13 +66,13 @@ func (task *Task) loadRepo(ctx context.Context) error { | ||||
| } | ||||
|  | ||||
| // LoadDoer loads do user | ||||
| func (task *Task) LoadDoer() error { | ||||
| func (task *Task) LoadDoer(ctx context.Context) error { | ||||
| 	if task.Doer != nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	var doer user_model.User | ||||
| 	has, err := db.GetEngine(db.DefaultContext).ID(task.DoerID).Get(&doer) | ||||
| 	has, err := db.GetEngine(ctx).ID(task.DoerID).Get(&doer) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} else if !has { | ||||
| @@ -90,13 +86,13 @@ func (task *Task) LoadDoer() error { | ||||
| } | ||||
|  | ||||
| // LoadOwner loads owner user | ||||
| func (task *Task) LoadOwner() error { | ||||
| func (task *Task) LoadOwner(ctx context.Context) error { | ||||
| 	if task.Owner != nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	var owner user_model.User | ||||
| 	has, err := db.GetEngine(db.DefaultContext).ID(task.OwnerID).Get(&owner) | ||||
| 	has, err := db.GetEngine(ctx).ID(task.OwnerID).Get(&owner) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} else if !has { | ||||
| @@ -110,8 +106,8 @@ func (task *Task) LoadOwner() error { | ||||
| } | ||||
|  | ||||
| // UpdateCols updates some columns | ||||
| func (task *Task) UpdateCols(cols ...string) error { | ||||
| 	_, err := db.GetEngine(db.DefaultContext).ID(task.ID).Cols(cols...).Update(task) | ||||
| func (task *Task) UpdateCols(ctx context.Context, cols ...string) error { | ||||
| 	_, err := db.GetEngine(ctx).ID(task.ID).Cols(cols...).Update(task) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| @@ -169,12 +165,12 @@ func (err ErrTaskDoesNotExist) Unwrap() error { | ||||
| } | ||||
|  | ||||
| // GetMigratingTask returns the migrating task by repo's id | ||||
| func GetMigratingTask(repoID int64) (*Task, error) { | ||||
| func GetMigratingTask(ctx context.Context, repoID int64) (*Task, error) { | ||||
| 	task := Task{ | ||||
| 		RepoID: repoID, | ||||
| 		Type:   structs.TaskTypeMigrateRepo, | ||||
| 	} | ||||
| 	has, err := db.GetEngine(db.DefaultContext).Get(&task) | ||||
| 	has, err := db.GetEngine(ctx).Get(&task) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} else if !has { | ||||
| @@ -184,13 +180,13 @@ func GetMigratingTask(repoID int64) (*Task, error) { | ||||
| } | ||||
|  | ||||
| // GetMigratingTaskByID returns the migrating task by repo's id | ||||
| func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, error) { | ||||
| func GetMigratingTaskByID(ctx context.Context, id, doerID int64) (*Task, *migration.MigrateOptions, error) { | ||||
| 	task := Task{ | ||||
| 		ID:     id, | ||||
| 		DoerID: doerID, | ||||
| 		Type:   structs.TaskTypeMigrateRepo, | ||||
| 	} | ||||
| 	has, err := db.GetEngine(db.DefaultContext).Get(&task) | ||||
| 	has, err := db.GetEngine(ctx).Get(&task) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} else if !has { | ||||
| @@ -205,12 +201,12 @@ func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, e | ||||
| } | ||||
|  | ||||
| // CreateTask creates a task on database | ||||
| func CreateTask(task *Task) error { | ||||
| 	return db.Insert(db.DefaultContext, task) | ||||
| func CreateTask(ctx context.Context, task *Task) error { | ||||
| 	return db.Insert(ctx, task) | ||||
| } | ||||
|  | ||||
| // FinishMigrateTask updates database when migrate task finished | ||||
| func FinishMigrateTask(task *Task) error { | ||||
| func FinishMigrateTask(ctx context.Context, task *Task) error { | ||||
| 	task.Status = structs.TaskStatusFinished | ||||
| 	task.EndTime = timeutil.TimeStampNow() | ||||
|  | ||||
| @@ -231,6 +227,6 @@ func FinishMigrateTask(task *Task) error { | ||||
| 	} | ||||
| 	task.PayloadContent = string(confBytes) | ||||
|  | ||||
| 	_, err = db.GetEngine(db.DefaultContext).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task) | ||||
| 	_, err = db.GetEngine(ctx).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task) | ||||
| 	return err | ||||
| } | ||||
|   | ||||
| @@ -4,6 +4,7 @@ | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"code.gitea.io/gitea/models/db" | ||||
| @@ -22,8 +23,8 @@ func init() { | ||||
| } | ||||
|  | ||||
| // UpdateSession updates the session with provided id | ||||
| func UpdateSession(key string, data []byte) error { | ||||
| 	_, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{ | ||||
| func UpdateSession(ctx context.Context, key string, data []byte) error { | ||||
| 	_, err := db.GetEngine(ctx).ID(key).Update(&Session{ | ||||
| 		Data:   data, | ||||
| 		Expiry: timeutil.TimeStampNow(), | ||||
| 	}) | ||||
| @@ -31,12 +32,12 @@ func UpdateSession(key string, data []byte) error { | ||||
| } | ||||
|  | ||||
| // ReadSession reads the data for the provided session | ||||
| func ReadSession(key string) (*Session, error) { | ||||
| func ReadSession(ctx context.Context, key string) (*Session, error) { | ||||
| 	session := Session{ | ||||
| 		Key: key, | ||||
| 	} | ||||
|  | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -55,24 +56,24 @@ func ReadSession(key string) (*Session, error) { | ||||
| } | ||||
|  | ||||
| // ExistSession checks if a session exists | ||||
| func ExistSession(key string) (bool, error) { | ||||
| func ExistSession(ctx context.Context, key string) (bool, error) { | ||||
| 	session := Session{ | ||||
| 		Key: key, | ||||
| 	} | ||||
| 	return db.GetEngine(db.DefaultContext).Get(&session) | ||||
| 	return db.GetEngine(ctx).Get(&session) | ||||
| } | ||||
|  | ||||
| // DestroySession destroys a session | ||||
| func DestroySession(key string) error { | ||||
| 	_, err := db.GetEngine(db.DefaultContext).Delete(&Session{ | ||||
| func DestroySession(ctx context.Context, key string) error { | ||||
| 	_, err := db.GetEngine(ctx).Delete(&Session{ | ||||
| 		Key: key, | ||||
| 	}) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // RegenerateSession regenerates a session from the old id | ||||
| func RegenerateSession(oldKey, newKey string) (*Session, error) { | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) { | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -114,12 +115,12 @@ func RegenerateSession(oldKey, newKey string) (*Session, error) { | ||||
| } | ||||
|  | ||||
| // CountSessions returns the number of sessions | ||||
| func CountSessions() (int64, error) { | ||||
| 	return db.GetEngine(db.DefaultContext).Count(&Session{}) | ||||
| func CountSessions(ctx context.Context) (int64, error) { | ||||
| 	return db.GetEngine(ctx).Count(&Session{}) | ||||
| } | ||||
|  | ||||
| // CleanupSessions cleans up expired sessions | ||||
| func CleanupSessions(maxLifetime int64) error { | ||||
| 	_, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{}) | ||||
| func CleanupSessions(ctx context.Context, maxLifetime int64) error { | ||||
| 	_, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{}) | ||||
| 	return err | ||||
| } | ||||
|   | ||||
| @@ -67,11 +67,7 @@ func (cred WebAuthnCredential) TableName() string { | ||||
| } | ||||
|  | ||||
| // UpdateSignCount will update the database value of SignCount | ||||
| func (cred *WebAuthnCredential) UpdateSignCount() error { | ||||
| 	return cred.updateSignCount(db.DefaultContext) | ||||
| } | ||||
|  | ||||
| func (cred *WebAuthnCredential) updateSignCount(ctx context.Context) error { | ||||
| func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error { | ||||
| 	_, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred) | ||||
| 	return err | ||||
| } | ||||
| @@ -113,30 +109,18 @@ func (list WebAuthnCredentialList) ToCredentials() []webauthn.Credential { | ||||
| } | ||||
|  | ||||
| // GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user | ||||
| func GetWebAuthnCredentialsByUID(uid int64) (WebAuthnCredentialList, error) { | ||||
| 	return getWebAuthnCredentialsByUID(db.DefaultContext, uid) | ||||
| } | ||||
|  | ||||
| func getWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) { | ||||
| func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) { | ||||
| 	creds := make(WebAuthnCredentialList, 0) | ||||
| 	return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds) | ||||
| } | ||||
|  | ||||
| // ExistsWebAuthnCredentialsForUID returns if the given user has credentials | ||||
| func ExistsWebAuthnCredentialsForUID(uid int64) (bool, error) { | ||||
| 	return existsWebAuthnCredentialsByUID(db.DefaultContext, uid) | ||||
| } | ||||
|  | ||||
| func existsWebAuthnCredentialsByUID(ctx context.Context, uid int64) (bool, error) { | ||||
| func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) { | ||||
| 	return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{}) | ||||
| } | ||||
|  | ||||
| // GetWebAuthnCredentialByName returns WebAuthn credential by id | ||||
| func GetWebAuthnCredentialByName(uid int64, name string) (*WebAuthnCredential, error) { | ||||
| 	return getWebAuthnCredentialByName(db.DefaultContext, uid, name) | ||||
| } | ||||
|  | ||||
| func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) { | ||||
| func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) { | ||||
| 	cred := new(WebAuthnCredential) | ||||
| 	if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil { | ||||
| 		return nil, err | ||||
| @@ -147,11 +131,7 @@ func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (* | ||||
| } | ||||
|  | ||||
| // GetWebAuthnCredentialByID returns WebAuthn credential by id | ||||
| func GetWebAuthnCredentialByID(id int64) (*WebAuthnCredential, error) { | ||||
| 	return getWebAuthnCredentialByID(db.DefaultContext, id) | ||||
| } | ||||
|  | ||||
| func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) { | ||||
| func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) { | ||||
| 	cred := new(WebAuthnCredential) | ||||
| 	if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil { | ||||
| 		return nil, err | ||||
| @@ -162,16 +142,12 @@ func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredenti | ||||
| } | ||||
|  | ||||
| // HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations | ||||
| func HasWebAuthnRegistrationsByUID(uid int64) (bool, error) { | ||||
| 	return db.GetEngine(db.DefaultContext).Where("user_id = ?", uid).Exist(&WebAuthnCredential{}) | ||||
| func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) { | ||||
| 	return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{}) | ||||
| } | ||||
|  | ||||
| // GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID | ||||
| func GetWebAuthnCredentialByCredID(userID int64, credID []byte) (*WebAuthnCredential, error) { | ||||
| 	return getWebAuthnCredentialByCredID(db.DefaultContext, userID, credID) | ||||
| } | ||||
|  | ||||
| func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) { | ||||
| func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) { | ||||
| 	cred := new(WebAuthnCredential) | ||||
| 	if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil { | ||||
| 		return nil, err | ||||
| @@ -182,11 +158,7 @@ func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []b | ||||
| } | ||||
|  | ||||
| // CreateCredential will create a new WebAuthnCredential from the given Credential | ||||
| func CreateCredential(userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) { | ||||
| 	return createCredential(db.DefaultContext, userID, name, cred) | ||||
| } | ||||
|  | ||||
| func createCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) { | ||||
| func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) { | ||||
| 	c := &WebAuthnCredential{ | ||||
| 		UserID:          userID, | ||||
| 		Name:            name, | ||||
| @@ -205,18 +177,14 @@ func createCredential(ctx context.Context, userID int64, name string, cred *weba | ||||
| } | ||||
|  | ||||
| // DeleteCredential will delete WebAuthnCredential | ||||
| func DeleteCredential(id, userID int64) (bool, error) { | ||||
| 	return deleteCredential(db.DefaultContext, id, userID) | ||||
| } | ||||
|  | ||||
| func deleteCredential(ctx context.Context, id, userID int64) (bool, error) { | ||||
| func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) { | ||||
| 	had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{}) | ||||
| 	return had > 0, err | ||||
| } | ||||
|  | ||||
| // WebAuthnCredentials implementns the webauthn.User interface | ||||
| func WebAuthnCredentials(userID int64) ([]webauthn.Credential, error) { | ||||
| 	dbCreds, err := GetWebAuthnCredentialsByUID(userID) | ||||
| func WebAuthnCredentials(ctx context.Context, userID int64) ([]webauthn.Credential, error) { | ||||
| 	dbCreds, err := GetWebAuthnCredentialsByUID(ctx, userID) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|   | ||||
| @@ -7,6 +7,7 @@ import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	auth_model "code.gitea.io/gitea/models/auth" | ||||
| 	"code.gitea.io/gitea/models/db" | ||||
| 	"code.gitea.io/gitea/models/unittest" | ||||
|  | ||||
| 	"github.com/go-webauthn/webauthn/webauthn" | ||||
| @@ -16,11 +17,11 @@ import ( | ||||
| func TestGetWebAuthnCredentialByID(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
|  | ||||
| 	res, err := auth_model.GetWebAuthnCredentialByID(1) | ||||
| 	res, err := auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 1) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Equal(t, "WebAuthn credential", res.Name) | ||||
|  | ||||
| 	_, err = auth_model.GetWebAuthnCredentialByID(342432) | ||||
| 	_, err = auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 342432) | ||||
| 	assert.Error(t, err) | ||||
| 	assert.True(t, auth_model.IsErrWebAuthnCredentialNotExist(err)) | ||||
| } | ||||
| @@ -28,7 +29,7 @@ func TestGetWebAuthnCredentialByID(t *testing.T) { | ||||
| func TestGetWebAuthnCredentialsByUID(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
|  | ||||
| 	res, err := auth_model.GetWebAuthnCredentialsByUID(32) | ||||
| 	res, err := auth_model.GetWebAuthnCredentialsByUID(db.DefaultContext, 32) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Len(t, res, 1) | ||||
| 	assert.Equal(t, "WebAuthn credential", res[0].Name) | ||||
| @@ -42,7 +43,7 @@ func TestWebAuthnCredential_UpdateSignCount(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
| 	cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1}) | ||||
| 	cred.SignCount = 1 | ||||
| 	assert.NoError(t, cred.UpdateSignCount()) | ||||
| 	assert.NoError(t, cred.UpdateSignCount(db.DefaultContext)) | ||||
| 	unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1}) | ||||
| } | ||||
|  | ||||
| @@ -50,14 +51,14 @@ func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
| 	cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1}) | ||||
| 	cred.SignCount = 0xffffffff | ||||
| 	assert.NoError(t, cred.UpdateSignCount()) | ||||
| 	assert.NoError(t, cred.UpdateSignCount(db.DefaultContext)) | ||||
| 	unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff}) | ||||
| } | ||||
|  | ||||
| func TestCreateCredential(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
|  | ||||
| 	res, err := auth_model.CreateCredential(1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")}) | ||||
| 	res, err := auth_model.CreateCredential(db.DefaultContext, 1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")}) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Equal(t, "WebAuthn Created Credential", res.Name) | ||||
| 	assert.Equal(t, []byte("Test"), res.CredentialID) | ||||
|   | ||||
| @@ -385,7 +385,7 @@ func TestMilestoneList_LoadTotalTrackedTimes(t *testing.T) { | ||||
| 		unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}), | ||||
| 	} | ||||
|  | ||||
| 	assert.NoError(t, miles.LoadTotalTrackedTimes()) | ||||
| 	assert.NoError(t, miles.LoadTotalTrackedTimes(db.DefaultContext)) | ||||
|  | ||||
| 	assert.Equal(t, int64(3682), miles[0].TotalTrackedTime) | ||||
| } | ||||
| @@ -394,7 +394,7 @@ func TestLoadTotalTrackedTime(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
| 	milestone := unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}) | ||||
|  | ||||
| 	assert.NoError(t, milestone.LoadTotalTrackedTime()) | ||||
| 	assert.NoError(t, milestone.LoadTotalTrackedTime(db.DefaultContext)) | ||||
|  | ||||
| 	assert.Equal(t, int64(3682), milestone.TotalTrackedTime) | ||||
| } | ||||
|   | ||||
| @@ -30,8 +30,8 @@ func init() { | ||||
| type IssueWatchList []*IssueWatch | ||||
|  | ||||
| // CreateOrUpdateIssueWatch set watching for a user and issue | ||||
| func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error { | ||||
| 	iw, exists, err := GetIssueWatch(db.DefaultContext, userID, issueID) | ||||
| func CreateOrUpdateIssueWatch(ctx context.Context, userID, issueID int64, isWatching bool) error { | ||||
| 	iw, exists, err := GetIssueWatch(ctx, userID, issueID) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -43,13 +43,13 @@ func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error { | ||||
| 			IsWatching: isWatching, | ||||
| 		} | ||||
|  | ||||
| 		if _, err := db.GetEngine(db.DefaultContext).Insert(iw); err != nil { | ||||
| 		if _, err := db.GetEngine(ctx).Insert(iw); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		iw.IsWatching = isWatching | ||||
|  | ||||
| 		if _, err := db.GetEngine(db.DefaultContext).ID(iw.ID).Cols("is_watching", "updated_unix").Update(iw); err != nil { | ||||
| 		if _, err := db.GetEngine(ctx).ID(iw.ID).Cols("is_watching", "updated_unix").Update(iw); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| @@ -69,15 +69,15 @@ func GetIssueWatch(ctx context.Context, userID, issueID int64) (iw *IssueWatch, | ||||
|  | ||||
| // CheckIssueWatch check if an user is watching an issue | ||||
| // it takes participants and repo watch into account | ||||
| func CheckIssueWatch(user *user_model.User, issue *Issue) (bool, error) { | ||||
| 	iw, exist, err := GetIssueWatch(db.DefaultContext, user.ID, issue.ID) | ||||
| func CheckIssueWatch(ctx context.Context, user *user_model.User, issue *Issue) (bool, error) { | ||||
| 	iw, exist, err := GetIssueWatch(ctx, user.ID, issue.ID) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 	if exist { | ||||
| 		return iw.IsWatching, nil | ||||
| 	} | ||||
| 	w, err := repo_model.GetWatch(db.DefaultContext, user.ID, issue.RepoID) | ||||
| 	w, err := repo_model.GetWatch(ctx, user.ID, issue.RepoID) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
|   | ||||
| @@ -16,11 +16,11 @@ import ( | ||||
| func TestCreateOrUpdateIssueWatch(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
|  | ||||
| 	assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(3, 1, true)) | ||||
| 	assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(db.DefaultContext, 3, 1, true)) | ||||
| 	iw := unittest.AssertExistsAndLoadBean(t, &issues_model.IssueWatch{UserID: 3, IssueID: 1}) | ||||
| 	assert.True(t, iw.IsWatching) | ||||
|  | ||||
| 	assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(1, 1, false)) | ||||
| 	assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(db.DefaultContext, 1, 1, false)) | ||||
| 	iw = unittest.AssertExistsAndLoadBean(t, &issues_model.IssueWatch{UserID: 1, IssueID: 1}) | ||||
| 	assert.False(t, iw.IsWatching) | ||||
| } | ||||
|   | ||||
| @@ -199,8 +199,8 @@ func NewLabel(ctx context.Context, l *Label) error { | ||||
| } | ||||
|  | ||||
| // NewLabels creates new labels | ||||
| func NewLabels(labels ...*Label) error { | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| func NewLabels(ctx context.Context, labels ...*Label) error { | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -221,19 +221,19 @@ func NewLabels(labels ...*Label) error { | ||||
| } | ||||
|  | ||||
| // UpdateLabel updates label information. | ||||
| func UpdateLabel(l *Label) error { | ||||
| func UpdateLabel(ctx context.Context, l *Label) error { | ||||
| 	color, err := label.NormalizeColor(l.Color) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	l.Color = color | ||||
|  | ||||
| 	return updateLabelCols(db.DefaultContext, l, "name", "description", "color", "exclusive", "archived_unix") | ||||
| 	return updateLabelCols(ctx, l, "name", "description", "color", "exclusive", "archived_unix") | ||||
| } | ||||
|  | ||||
| // DeleteLabel delete a label | ||||
| func DeleteLabel(id, labelID int64) error { | ||||
| 	l, err := GetLabelByID(db.DefaultContext, labelID) | ||||
| func DeleteLabel(ctx context.Context, id, labelID int64) error { | ||||
| 	l, err := GetLabelByID(ctx, labelID) | ||||
| 	if err != nil { | ||||
| 		if IsErrLabelNotExist(err) { | ||||
| 			return nil | ||||
| @@ -241,7 +241,7 @@ func DeleteLabel(id, labelID int64) error { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -289,9 +289,9 @@ func GetLabelByID(ctx context.Context, labelID int64) (*Label, error) { | ||||
| } | ||||
|  | ||||
| // GetLabelsByIDs returns a list of labels by IDs | ||||
| func GetLabelsByIDs(labelIDs []int64, cols ...string) ([]*Label, error) { | ||||
| func GetLabelsByIDs(ctx context.Context, labelIDs []int64, cols ...string) ([]*Label, error) { | ||||
| 	labels := make([]*Label, 0, len(labelIDs)) | ||||
| 	return labels, db.GetEngine(db.DefaultContext).Table("label"). | ||||
| 	return labels, db.GetEngine(ctx).Table("label"). | ||||
| 		In("id", labelIDs). | ||||
| 		Asc("name"). | ||||
| 		Cols(cols...). | ||||
| @@ -339,9 +339,9 @@ func GetLabelInRepoByID(ctx context.Context, repoID, labelID int64) (*Label, err | ||||
| // GetLabelIDsInRepoByNames returns a list of labelIDs by names in a given | ||||
| // repository. | ||||
| // it silently ignores label names that do not belong to the repository. | ||||
| func GetLabelIDsInRepoByNames(repoID int64, labelNames []string) ([]int64, error) { | ||||
| func GetLabelIDsInRepoByNames(ctx context.Context, repoID int64, labelNames []string) ([]int64, error) { | ||||
| 	labelIDs := make([]int64, 0, len(labelNames)) | ||||
| 	return labelIDs, db.GetEngine(db.DefaultContext).Table("label"). | ||||
| 	return labelIDs, db.GetEngine(ctx).Table("label"). | ||||
| 		Where("repo_id = ?", repoID). | ||||
| 		In("name", labelNames). | ||||
| 		Asc("name"). | ||||
| @@ -398,8 +398,8 @@ func GetLabelsByRepoID(ctx context.Context, repoID int64, sortType string, listO | ||||
| } | ||||
|  | ||||
| // CountLabelsByRepoID count number of all labels that belong to given repository by ID. | ||||
| func CountLabelsByRepoID(repoID int64) (int64, error) { | ||||
| 	return db.GetEngine(db.DefaultContext).Where("repo_id = ?", repoID).Count(&Label{}) | ||||
| func CountLabelsByRepoID(ctx context.Context, repoID int64) (int64, error) { | ||||
| 	return db.GetEngine(ctx).Where("repo_id = ?", repoID).Count(&Label{}) | ||||
| } | ||||
|  | ||||
| // GetLabelInOrgByName returns a label by name in given organization. | ||||
| @@ -442,13 +442,13 @@ func GetLabelInOrgByID(ctx context.Context, orgID, labelID int64) (*Label, error | ||||
|  | ||||
| // GetLabelIDsInOrgByNames returns a list of labelIDs by names in a given | ||||
| // organization. | ||||
| func GetLabelIDsInOrgByNames(orgID int64, labelNames []string) ([]int64, error) { | ||||
| func GetLabelIDsInOrgByNames(ctx context.Context, orgID int64, labelNames []string) ([]int64, error) { | ||||
| 	if orgID <= 0 { | ||||
| 		return nil, ErrOrgLabelNotExist{0, orgID} | ||||
| 	} | ||||
| 	labelIDs := make([]int64, 0, len(labelNames)) | ||||
|  | ||||
| 	return labelIDs, db.GetEngine(db.DefaultContext).Table("label"). | ||||
| 	return labelIDs, db.GetEngine(ctx).Table("label"). | ||||
| 		Where("org_id = ?", orgID). | ||||
| 		In("name", labelNames). | ||||
| 		Asc("name"). | ||||
| @@ -506,8 +506,8 @@ func GetLabelIDsByNames(ctx context.Context, labelNames []string) ([]int64, erro | ||||
| } | ||||
|  | ||||
| // CountLabelsByOrgID count all labels that belong to given organization by ID. | ||||
| func CountLabelsByOrgID(orgID int64) (int64, error) { | ||||
| 	return db.GetEngine(db.DefaultContext).Where("org_id = ?", orgID).Count(&Label{}) | ||||
| func CountLabelsByOrgID(ctx context.Context, orgID int64) (int64, error) { | ||||
| 	return db.GetEngine(ctx).Where("org_id = ?", orgID).Count(&Label{}) | ||||
| } | ||||
|  | ||||
| func updateLabelCols(ctx context.Context, l *Label, cols ...string) error { | ||||
|   | ||||
| @@ -48,7 +48,7 @@ func TestNewLabels(t *testing.T) { | ||||
| 	for _, label := range labels { | ||||
| 		unittest.AssertNotExistsBean(t, label) | ||||
| 	} | ||||
| 	assert.NoError(t, issues_model.NewLabels(labels...)) | ||||
| 	assert.NoError(t, issues_model.NewLabels(db.DefaultContext, labels...)) | ||||
| 	for _, label := range labels { | ||||
| 		unittest.AssertExistsAndLoadBean(t, label, unittest.Cond("id = ?", label.ID)) | ||||
| 	} | ||||
| @@ -81,7 +81,7 @@ func TestGetLabelInRepoByName(t *testing.T) { | ||||
|  | ||||
| func TestGetLabelInRepoByNames(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
| 	labelIDs, err := issues_model.GetLabelIDsInRepoByNames(1, []string{"label1", "label2"}) | ||||
| 	labelIDs, err := issues_model.GetLabelIDsInRepoByNames(db.DefaultContext, 1, []string{"label1", "label2"}) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	assert.Len(t, labelIDs, 2) | ||||
| @@ -93,7 +93,7 @@ func TestGetLabelInRepoByNames(t *testing.T) { | ||||
| func TestGetLabelInRepoByNamesDiscardsNonExistentLabels(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
| 	// label3 doesn't exists.. See labels.yml | ||||
| 	labelIDs, err := issues_model.GetLabelIDsInRepoByNames(1, []string{"label1", "label2", "label3"}) | ||||
| 	labelIDs, err := issues_model.GetLabelIDsInRepoByNames(db.DefaultContext, 1, []string{"label1", "label2", "label3"}) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	assert.Len(t, labelIDs, 2) | ||||
| @@ -166,7 +166,7 @@ func TestGetLabelInOrgByName(t *testing.T) { | ||||
|  | ||||
| func TestGetLabelInOrgByNames(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
| 	labelIDs, err := issues_model.GetLabelIDsInOrgByNames(3, []string{"orglabel3", "orglabel4"}) | ||||
| 	labelIDs, err := issues_model.GetLabelIDsInOrgByNames(db.DefaultContext, 3, []string{"orglabel3", "orglabel4"}) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	assert.Len(t, labelIDs, 2) | ||||
| @@ -178,7 +178,7 @@ func TestGetLabelInOrgByNames(t *testing.T) { | ||||
| func TestGetLabelInOrgByNamesDiscardsNonExistentLabels(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
| 	// orglabel99 doesn't exists.. See labels.yml | ||||
| 	labelIDs, err := issues_model.GetLabelIDsInOrgByNames(3, []string{"orglabel3", "orglabel4", "orglabel99"}) | ||||
| 	labelIDs, err := issues_model.GetLabelIDsInOrgByNames(db.DefaultContext, 3, []string{"orglabel3", "orglabel4", "orglabel99"}) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	assert.Len(t, labelIDs, 2) | ||||
| @@ -269,7 +269,7 @@ func TestUpdateLabel(t *testing.T) { | ||||
| 	} | ||||
| 	label.Color = update.Color | ||||
| 	label.Name = update.Name | ||||
| 	assert.NoError(t, issues_model.UpdateLabel(update)) | ||||
| 	assert.NoError(t, issues_model.UpdateLabel(db.DefaultContext, update)) | ||||
| 	newLabel := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 1}) | ||||
| 	assert.EqualValues(t, label.ID, newLabel.ID) | ||||
| 	assert.EqualValues(t, label.Color, newLabel.Color) | ||||
| @@ -282,13 +282,13 @@ func TestUpdateLabel(t *testing.T) { | ||||
| func TestDeleteLabel(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
| 	label := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 1}) | ||||
| 	assert.NoError(t, issues_model.DeleteLabel(label.RepoID, label.ID)) | ||||
| 	assert.NoError(t, issues_model.DeleteLabel(db.DefaultContext, label.RepoID, label.ID)) | ||||
| 	unittest.AssertNotExistsBean(t, &issues_model.Label{ID: label.ID, RepoID: label.RepoID}) | ||||
|  | ||||
| 	assert.NoError(t, issues_model.DeleteLabel(label.RepoID, label.ID)) | ||||
| 	assert.NoError(t, issues_model.DeleteLabel(db.DefaultContext, label.RepoID, label.ID)) | ||||
| 	unittest.AssertNotExistsBean(t, &issues_model.Label{ID: label.ID}) | ||||
|  | ||||
| 	assert.NoError(t, issues_model.DeleteLabel(unittest.NonexistentID, unittest.NonexistentID)) | ||||
| 	assert.NoError(t, issues_model.DeleteLabel(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID)) | ||||
| 	unittest.CheckConsistencyFor(t, &issues_model.Label{}, &repo_model.Repository{}) | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -103,8 +103,8 @@ func (m *Milestone) State() api.StateType { | ||||
| } | ||||
|  | ||||
| // NewMilestone creates new milestone of repository. | ||||
| func NewMilestone(m *Milestone) (err error) { | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| func NewMilestone(ctx context.Context, m *Milestone) (err error) { | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -140,9 +140,9 @@ func GetMilestoneByRepoID(ctx context.Context, repoID, id int64) (*Milestone, er | ||||
| } | ||||
|  | ||||
| // GetMilestoneByRepoIDANDName return a milestone if one exist by name and repo | ||||
| func GetMilestoneByRepoIDANDName(repoID int64, name string) (*Milestone, error) { | ||||
| func GetMilestoneByRepoIDANDName(ctx context.Context, repoID int64, name string) (*Milestone, error) { | ||||
| 	var mile Milestone | ||||
| 	has, err := db.GetEngine(db.DefaultContext).Where("repo_id=? AND name=?", repoID, name).Get(&mile) | ||||
| 	has, err := db.GetEngine(ctx).Where("repo_id=? AND name=?", repoID, name).Get(&mile) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -153,8 +153,8 @@ func GetMilestoneByRepoIDANDName(repoID int64, name string) (*Milestone, error) | ||||
| } | ||||
|  | ||||
| // UpdateMilestone updates information of given milestone. | ||||
| func UpdateMilestone(m *Milestone, oldIsClosed bool) error { | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| func UpdateMilestone(ctx context.Context, m *Milestone, oldIsClosed bool) error { | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -211,8 +211,8 @@ func UpdateMilestoneCounters(ctx context.Context, id int64) error { | ||||
| } | ||||
|  | ||||
| // ChangeMilestoneStatusByRepoIDAndID changes a milestone open/closed status if the milestone ID is in the repo. | ||||
| func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool) error { | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| func ChangeMilestoneStatusByRepoIDAndID(ctx context.Context, repoID, milestoneID int64, isClosed bool) error { | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -238,8 +238,8 @@ func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool | ||||
| } | ||||
|  | ||||
| // ChangeMilestoneStatus changes the milestone open/closed status. | ||||
| func ChangeMilestoneStatus(m *Milestone, isClosed bool) (err error) { | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| func ChangeMilestoneStatus(ctx context.Context, m *Milestone, isClosed bool) (err error) { | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -269,8 +269,8 @@ func changeMilestoneStatus(ctx context.Context, m *Milestone, isClosed bool) err | ||||
| } | ||||
|  | ||||
| // DeleteMilestoneByRepoID deletes a milestone from a repository. | ||||
| func DeleteMilestoneByRepoID(repoID, id int64) error { | ||||
| 	m, err := GetMilestoneByRepoID(db.DefaultContext, repoID, id) | ||||
| func DeleteMilestoneByRepoID(ctx context.Context, repoID, id int64) error { | ||||
| 	m, err := GetMilestoneByRepoID(ctx, repoID, id) | ||||
| 	if err != nil { | ||||
| 		if IsErrMilestoneNotExist(err) { | ||||
| 			return nil | ||||
| @@ -278,12 +278,12 @@ func DeleteMilestoneByRepoID(repoID, id int64) error { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	repo, err := repo_model.GetRepositoryByID(db.DefaultContext, m.RepoID) | ||||
| 	repo, err := repo_model.GetRepositoryByID(ctx, m.RepoID) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -332,7 +332,8 @@ func updateRepoMilestoneNum(ctx context.Context, repoID int64) error { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (m *Milestone) loadTotalTrackedTime(ctx context.Context) error { | ||||
| // LoadTotalTrackedTime loads the tracked time for the milestone | ||||
| func (m *Milestone) LoadTotalTrackedTime(ctx context.Context) error { | ||||
| 	type totalTimesByMilestone struct { | ||||
| 		MilestoneID int64 | ||||
| 		Time        int64 | ||||
| @@ -355,18 +356,13 @@ func (m *Milestone) loadTotalTrackedTime(ctx context.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // LoadTotalTrackedTime loads the tracked time for the milestone | ||||
| func (m *Milestone) LoadTotalTrackedTime() error { | ||||
| 	return m.loadTotalTrackedTime(db.DefaultContext) | ||||
| } | ||||
|  | ||||
| // InsertMilestones creates milestones of repository. | ||||
| func InsertMilestones(ms ...*Milestone) (err error) { | ||||
| func InsertMilestones(ctx context.Context, ms ...*Milestone) (err error) { | ||||
| 	if len(ms) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -100,9 +100,9 @@ func GetMilestoneIDsByNames(ctx context.Context, names []string) ([]int64, error | ||||
| } | ||||
|  | ||||
| // SearchMilestones search milestones | ||||
| func SearchMilestones(repoCond builder.Cond, page int, isClosed bool, sortType, keyword string) (MilestoneList, error) { | ||||
| func SearchMilestones(ctx context.Context, repoCond builder.Cond, page int, isClosed bool, sortType, keyword string) (MilestoneList, error) { | ||||
| 	miles := make([]*Milestone, 0, setting.UI.IssuePagingNum) | ||||
| 	sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed) | ||||
| 	sess := db.GetEngine(ctx).Where("is_closed = ?", isClosed) | ||||
| 	if len(keyword) > 0 { | ||||
| 		sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) | ||||
| 	} | ||||
| @@ -131,8 +131,9 @@ func SearchMilestones(repoCond builder.Cond, page int, isClosed bool, sortType, | ||||
| } | ||||
|  | ||||
| // GetMilestonesByRepoIDs returns a list of milestones of given repositories and status. | ||||
| func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType string) (MilestoneList, error) { | ||||
| func GetMilestonesByRepoIDs(ctx context.Context, repoIDs []int64, page int, isClosed bool, sortType string) (MilestoneList, error) { | ||||
| 	return SearchMilestones( | ||||
| 		ctx, | ||||
| 		builder.In("repo_id", repoIDs), | ||||
| 		page, | ||||
| 		isClosed, | ||||
| @@ -141,7 +142,8 @@ func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType s | ||||
| 	) | ||||
| } | ||||
|  | ||||
| func (milestones MilestoneList) loadTotalTrackedTimes(ctx context.Context) error { | ||||
| // LoadTotalTrackedTimes loads for every milestone in the list the TotalTrackedTime by a batch request | ||||
| func (milestones MilestoneList) LoadTotalTrackedTimes(ctx context.Context) error { | ||||
| 	type totalTimesByMilestone struct { | ||||
| 		MilestoneID int64 | ||||
| 		Time        int64 | ||||
| @@ -181,11 +183,6 @@ func (milestones MilestoneList) loadTotalTrackedTimes(ctx context.Context) error | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // LoadTotalTrackedTimes loads for every milestone in the list the TotalTrackedTime by a batch request | ||||
| func (milestones MilestoneList) LoadTotalTrackedTimes() error { | ||||
| 	return milestones.loadTotalTrackedTimes(db.DefaultContext) | ||||
| } | ||||
|  | ||||
| // CountMilestones returns number of milestones in given repository with other options | ||||
| func CountMilestones(ctx context.Context, opts GetMilestonesOption) (int64, error) { | ||||
| 	return db.GetEngine(ctx). | ||||
| @@ -194,8 +191,8 @@ func CountMilestones(ctx context.Context, opts GetMilestonesOption) (int64, erro | ||||
| } | ||||
|  | ||||
| // CountMilestonesByRepoCond map from repo conditions to number of milestones matching the options` | ||||
| func CountMilestonesByRepoCond(repoCond builder.Cond, isClosed bool) (map[int64]int64, error) { | ||||
| 	sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed) | ||||
| func CountMilestonesByRepoCond(ctx context.Context, repoCond builder.Cond, isClosed bool) (map[int64]int64, error) { | ||||
| 	sess := db.GetEngine(ctx).Where("is_closed = ?", isClosed) | ||||
| 	if repoCond.IsValid() { | ||||
| 		sess.In("repo_id", builder.Select("id").From("repository").Where(repoCond)) | ||||
| 	} | ||||
| @@ -219,8 +216,8 @@ func CountMilestonesByRepoCond(repoCond builder.Cond, isClosed bool) (map[int64] | ||||
| } | ||||
|  | ||||
| // CountMilestonesByRepoCondAndKw map from repo conditions and the keyword of milestones' name to number of milestones matching the options` | ||||
| func CountMilestonesByRepoCondAndKw(repoCond builder.Cond, keyword string, isClosed bool) (map[int64]int64, error) { | ||||
| 	sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed) | ||||
| func CountMilestonesByRepoCondAndKw(ctx context.Context, repoCond builder.Cond, keyword string, isClosed bool) (map[int64]int64, error) { | ||||
| 	sess := db.GetEngine(ctx).Where("is_closed = ?", isClosed) | ||||
| 	if len(keyword) > 0 { | ||||
| 		sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) | ||||
| 	} | ||||
| @@ -257,11 +254,11 @@ func (m MilestonesStats) Total() int64 { | ||||
| } | ||||
|  | ||||
| // GetMilestonesStatsByRepoCond returns milestone statistic information for dashboard by given conditions. | ||||
| func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, error) { | ||||
| func GetMilestonesStatsByRepoCond(ctx context.Context, repoCond builder.Cond) (*MilestonesStats, error) { | ||||
| 	var err error | ||||
| 	stats := &MilestonesStats{} | ||||
|  | ||||
| 	sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", false) | ||||
| 	sess := db.GetEngine(ctx).Where("is_closed = ?", false) | ||||
| 	if repoCond.IsValid() { | ||||
| 		sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond))) | ||||
| 	} | ||||
| @@ -270,7 +267,7 @@ func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, erro | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	sess = db.GetEngine(db.DefaultContext).Where("is_closed = ?", true) | ||||
| 	sess = db.GetEngine(ctx).Where("is_closed = ?", true) | ||||
| 	if repoCond.IsValid() { | ||||
| 		sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond))) | ||||
| 	} | ||||
| @@ -283,11 +280,11 @@ func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, erro | ||||
| } | ||||
|  | ||||
| // GetMilestonesStatsByRepoCondAndKw returns milestone statistic information for dashboard by given repo conditions and name keyword. | ||||
| func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (*MilestonesStats, error) { | ||||
| func GetMilestonesStatsByRepoCondAndKw(ctx context.Context, repoCond builder.Cond, keyword string) (*MilestonesStats, error) { | ||||
| 	var err error | ||||
| 	stats := &MilestonesStats{} | ||||
|  | ||||
| 	sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", false) | ||||
| 	sess := db.GetEngine(ctx).Where("is_closed = ?", false) | ||||
| 	if len(keyword) > 0 { | ||||
| 		sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) | ||||
| 	} | ||||
| @@ -299,7 +296,7 @@ func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (* | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	sess = db.GetEngine(db.DefaultContext).Where("is_closed = ?", true) | ||||
| 	sess = db.GetEngine(ctx).Where("is_closed = ?", true) | ||||
| 	if len(keyword) > 0 { | ||||
| 		sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) | ||||
| 	} | ||||
|   | ||||
| @@ -201,12 +201,12 @@ func TestCountMilestonesByRepoIDs(t *testing.T) { | ||||
| 	repo1OpenCount, repo1ClosedCount := milestonesCount(1) | ||||
| 	repo2OpenCount, repo2ClosedCount := milestonesCount(2) | ||||
|  | ||||
| 	openCounts, err := issues_model.CountMilestonesByRepoCond(builder.In("repo_id", []int64{1, 2}), false) | ||||
| 	openCounts, err := issues_model.CountMilestonesByRepoCond(db.DefaultContext, builder.In("repo_id", []int64{1, 2}), false) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.EqualValues(t, repo1OpenCount, openCounts[1]) | ||||
| 	assert.EqualValues(t, repo2OpenCount, openCounts[2]) | ||||
|  | ||||
| 	closedCounts, err := issues_model.CountMilestonesByRepoCond(builder.In("repo_id", []int64{1, 2}), true) | ||||
| 	closedCounts, err := issues_model.CountMilestonesByRepoCond(db.DefaultContext, builder.In("repo_id", []int64{1, 2}), true) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.EqualValues(t, repo1ClosedCount, closedCounts[1]) | ||||
| 	assert.EqualValues(t, repo2ClosedCount, closedCounts[2]) | ||||
| @@ -218,7 +218,7 @@ func TestGetMilestonesByRepoIDs(t *testing.T) { | ||||
| 	repo2 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 2}) | ||||
| 	test := func(sortType string, sortCond func(*issues_model.Milestone) int) { | ||||
| 		for _, page := range []int{0, 1} { | ||||
| 			openMilestones, err := issues_model.GetMilestonesByRepoIDs([]int64{repo1.ID, repo2.ID}, page, false, sortType) | ||||
| 			openMilestones, err := issues_model.GetMilestonesByRepoIDs(db.DefaultContext, []int64{repo1.ID, repo2.ID}, page, false, sortType) | ||||
| 			assert.NoError(t, err) | ||||
| 			assert.Len(t, openMilestones, repo1.NumOpenMilestones+repo2.NumOpenMilestones) | ||||
| 			values := make([]int, len(openMilestones)) | ||||
| @@ -227,7 +227,7 @@ func TestGetMilestonesByRepoIDs(t *testing.T) { | ||||
| 			} | ||||
| 			assert.True(t, sort.IntsAreSorted(values)) | ||||
|  | ||||
| 			closedMilestones, err := issues_model.GetMilestonesByRepoIDs([]int64{repo1.ID, repo2.ID}, page, true, sortType) | ||||
| 			closedMilestones, err := issues_model.GetMilestonesByRepoIDs(db.DefaultContext, []int64{repo1.ID, repo2.ID}, page, true, sortType) | ||||
| 			assert.NoError(t, err) | ||||
| 			assert.Len(t, closedMilestones, repo1.NumClosedMilestones+repo2.NumClosedMilestones) | ||||
| 			values = make([]int, len(closedMilestones)) | ||||
| @@ -262,7 +262,7 @@ func TestGetMilestonesStats(t *testing.T) { | ||||
|  | ||||
| 	test := func(repoID int64) { | ||||
| 		repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID}) | ||||
| 		stats, err := issues_model.GetMilestonesStatsByRepoCond(builder.And(builder.Eq{"repo_id": repoID})) | ||||
| 		stats, err := issues_model.GetMilestonesStatsByRepoCond(db.DefaultContext, builder.And(builder.Eq{"repo_id": repoID})) | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.EqualValues(t, repo.NumMilestones-repo.NumClosedMilestones, stats.OpenCount) | ||||
| 		assert.EqualValues(t, repo.NumClosedMilestones, stats.ClosedCount) | ||||
| @@ -271,7 +271,7 @@ func TestGetMilestonesStats(t *testing.T) { | ||||
| 	test(2) | ||||
| 	test(3) | ||||
|  | ||||
| 	stats, err := issues_model.GetMilestonesStatsByRepoCond(builder.And(builder.Eq{"repo_id": unittest.NonexistentID})) | ||||
| 	stats, err := issues_model.GetMilestonesStatsByRepoCond(db.DefaultContext, builder.And(builder.Eq{"repo_id": unittest.NonexistentID})) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.EqualValues(t, 0, stats.OpenCount) | ||||
| 	assert.EqualValues(t, 0, stats.ClosedCount) | ||||
| @@ -279,7 +279,7 @@ func TestGetMilestonesStats(t *testing.T) { | ||||
| 	repo1 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1}) | ||||
| 	repo2 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 2}) | ||||
|  | ||||
| 	milestoneStats, err := issues_model.GetMilestonesStatsByRepoCond(builder.In("repo_id", []int64{repo1.ID, repo2.ID})) | ||||
| 	milestoneStats, err := issues_model.GetMilestonesStatsByRepoCond(db.DefaultContext, builder.In("repo_id", []int64{repo1.ID, repo2.ID})) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.EqualValues(t, repo1.NumOpenMilestones+repo2.NumOpenMilestones, milestoneStats.OpenCount) | ||||
| 	assert.EqualValues(t, repo1.NumClosedMilestones+repo2.NumClosedMilestones, milestoneStats.ClosedCount) | ||||
| @@ -293,7 +293,7 @@ func TestNewMilestone(t *testing.T) { | ||||
| 		Content: "milestoneContent", | ||||
| 	} | ||||
|  | ||||
| 	assert.NoError(t, issues_model.NewMilestone(milestone)) | ||||
| 	assert.NoError(t, issues_model.NewMilestone(db.DefaultContext, milestone)) | ||||
| 	unittest.AssertExistsAndLoadBean(t, milestone) | ||||
| 	unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: milestone.RepoID}, &issues_model.Milestone{}) | ||||
| } | ||||
| @@ -302,22 +302,22 @@ func TestChangeMilestoneStatus(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
| 	milestone := unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}) | ||||
|  | ||||
| 	assert.NoError(t, issues_model.ChangeMilestoneStatus(milestone, true)) | ||||
| 	assert.NoError(t, issues_model.ChangeMilestoneStatus(db.DefaultContext, milestone, true)) | ||||
| 	unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}, "is_closed=1") | ||||
| 	unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: milestone.RepoID}, &issues_model.Milestone{}) | ||||
|  | ||||
| 	assert.NoError(t, issues_model.ChangeMilestoneStatus(milestone, false)) | ||||
| 	assert.NoError(t, issues_model.ChangeMilestoneStatus(db.DefaultContext, milestone, false)) | ||||
| 	unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}, "is_closed=0") | ||||
| 	unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: milestone.RepoID}, &issues_model.Milestone{}) | ||||
| } | ||||
|  | ||||
| func TestDeleteMilestoneByRepoID(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
| 	assert.NoError(t, issues_model.DeleteMilestoneByRepoID(1, 1)) | ||||
| 	assert.NoError(t, issues_model.DeleteMilestoneByRepoID(db.DefaultContext, 1, 1)) | ||||
| 	unittest.AssertNotExistsBean(t, &issues_model.Milestone{ID: 1}) | ||||
| 	unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: 1}) | ||||
|  | ||||
| 	assert.NoError(t, issues_model.DeleteMilestoneByRepoID(unittest.NonexistentID, unittest.NonexistentID)) | ||||
| 	assert.NoError(t, issues_model.DeleteMilestoneByRepoID(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID)) | ||||
| } | ||||
|  | ||||
| func TestUpdateMilestone(t *testing.T) { | ||||
| @@ -326,7 +326,7 @@ func TestUpdateMilestone(t *testing.T) { | ||||
| 	milestone := unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}) | ||||
| 	milestone.Name = " newMilestoneName  " | ||||
| 	milestone.Content = "newMilestoneContent" | ||||
| 	assert.NoError(t, issues_model.UpdateMilestone(milestone, milestone.IsClosed)) | ||||
| 	assert.NoError(t, issues_model.UpdateMilestone(db.DefaultContext, milestone, milestone.IsClosed)) | ||||
| 	milestone = unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}) | ||||
| 	assert.EqualValues(t, "newMilestoneName", milestone.Name) | ||||
| 	unittest.CheckConsistencyFor(t, &issues_model.Milestone{}) | ||||
| @@ -361,7 +361,7 @@ func TestMigrate_InsertMilestones(t *testing.T) { | ||||
| 		RepoID: repo.ID, | ||||
| 		Name:   name, | ||||
| 	} | ||||
| 	err := issues_model.InsertMilestones(ms) | ||||
| 	err := issues_model.InsertMilestones(db.DefaultContext, ms) | ||||
| 	assert.NoError(t, err) | ||||
| 	unittest.AssertExistsAndLoadBean(t, ms) | ||||
| 	repoModified := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repo.ID}) | ||||
|   | ||||
| @@ -81,9 +81,9 @@ type UserStopwatch struct { | ||||
| } | ||||
|  | ||||
| // GetUIDsAndNotificationCounts between the two provided times | ||||
| func GetUIDsAndStopwatch() ([]*UserStopwatch, error) { | ||||
| func GetUIDsAndStopwatch(ctx context.Context) ([]*UserStopwatch, error) { | ||||
| 	sws := []*Stopwatch{} | ||||
| 	if err := db.GetEngine(db.DefaultContext).Where("issue_id != 0").Find(&sws); err != nil { | ||||
| 	if err := db.GetEngine(ctx).Where("issue_id != 0").Find(&sws); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if len(sws) == 0 { | ||||
| @@ -107,9 +107,9 @@ func GetUIDsAndStopwatch() ([]*UserStopwatch, error) { | ||||
| } | ||||
|  | ||||
| // GetUserStopwatches return list of all stopwatches of a user | ||||
| func GetUserStopwatches(userID int64, listOptions db.ListOptions) ([]*Stopwatch, error) { | ||||
| func GetUserStopwatches(ctx context.Context, userID int64, listOptions db.ListOptions) ([]*Stopwatch, error) { | ||||
| 	sws := make([]*Stopwatch, 0, 8) | ||||
| 	sess := db.GetEngine(db.DefaultContext).Where("stopwatch.user_id = ?", userID) | ||||
| 	sess := db.GetEngine(ctx).Where("stopwatch.user_id = ?", userID) | ||||
| 	if listOptions.Page != 0 { | ||||
| 		sess = db.SetSessionPagination(sess, &listOptions) | ||||
| 	} | ||||
| @@ -122,13 +122,13 @@ func GetUserStopwatches(userID int64, listOptions db.ListOptions) ([]*Stopwatch, | ||||
| } | ||||
|  | ||||
| // CountUserStopwatches return count of all stopwatches of a user | ||||
| func CountUserStopwatches(userID int64) (int64, error) { | ||||
| 	return db.GetEngine(db.DefaultContext).Where("user_id = ?", userID).Count(&Stopwatch{}) | ||||
| func CountUserStopwatches(ctx context.Context, userID int64) (int64, error) { | ||||
| 	return db.GetEngine(ctx).Where("user_id = ?", userID).Count(&Stopwatch{}) | ||||
| } | ||||
|  | ||||
| // StopwatchExists returns true if the stopwatch exists | ||||
| func StopwatchExists(userID, issueID int64) bool { | ||||
| 	_, exists, _ := getStopwatch(db.DefaultContext, userID, issueID) | ||||
| func StopwatchExists(ctx context.Context, userID, issueID int64) bool { | ||||
| 	_, exists, _ := getStopwatch(ctx, userID, issueID) | ||||
| 	return exists | ||||
| } | ||||
|  | ||||
| @@ -168,15 +168,15 @@ func FinishIssueStopwatchIfPossible(ctx context.Context, user *user_model.User, | ||||
| } | ||||
|  | ||||
| // CreateOrStopIssueStopwatch create an issue stopwatch if it's not exist, otherwise finish it | ||||
| func CreateOrStopIssueStopwatch(user *user_model.User, issue *Issue) error { | ||||
| 	_, exists, err := getStopwatch(db.DefaultContext, user.ID, issue.ID) | ||||
| func CreateOrStopIssueStopwatch(ctx context.Context, user *user_model.User, issue *Issue) error { | ||||
| 	_, exists, err := getStopwatch(ctx, user.ID, issue.ID) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if exists { | ||||
| 		return FinishIssueStopwatch(db.DefaultContext, user, issue) | ||||
| 		return FinishIssueStopwatch(ctx, user, issue) | ||||
| 	} | ||||
| 	return CreateIssueStopwatch(db.DefaultContext, user, issue) | ||||
| 	return CreateIssueStopwatch(ctx, user, issue) | ||||
| } | ||||
|  | ||||
| // FinishIssueStopwatch if stopwatch exist then finish it otherwise return an error | ||||
| @@ -269,8 +269,8 @@ func CreateIssueStopwatch(ctx context.Context, user *user_model.User, issue *Iss | ||||
| } | ||||
|  | ||||
| // CancelStopwatch removes the given stopwatch and logs it into issue's timeline. | ||||
| func CancelStopwatch(user *user_model.User, issue *Issue) error { | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| func CancelStopwatch(ctx context.Context, user *user_model.User, issue *Issue) error { | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -26,20 +26,20 @@ func TestCancelStopwatch(t *testing.T) { | ||||
| 	issue2, err := issues_model.GetIssueByID(db.DefaultContext, 2) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	err = issues_model.CancelStopwatch(user1, issue1) | ||||
| 	err = issues_model.CancelStopwatch(db.DefaultContext, user1, issue1) | ||||
| 	assert.NoError(t, err) | ||||
| 	unittest.AssertNotExistsBean(t, &issues_model.Stopwatch{UserID: user1.ID, IssueID: issue1.ID}) | ||||
|  | ||||
| 	_ = unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{Type: issues_model.CommentTypeCancelTracking, PosterID: user1.ID, IssueID: issue1.ID}) | ||||
|  | ||||
| 	assert.Nil(t, issues_model.CancelStopwatch(user1, issue2)) | ||||
| 	assert.Nil(t, issues_model.CancelStopwatch(db.DefaultContext, user1, issue2)) | ||||
| } | ||||
|  | ||||
| func TestStopwatchExists(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
|  | ||||
| 	assert.True(t, issues_model.StopwatchExists(1, 1)) | ||||
| 	assert.False(t, issues_model.StopwatchExists(1, 2)) | ||||
| 	assert.True(t, issues_model.StopwatchExists(db.DefaultContext, 1, 1)) | ||||
| 	assert.False(t, issues_model.StopwatchExists(db.DefaultContext, 1, 2)) | ||||
| } | ||||
|  | ||||
| func TestHasUserStopwatch(t *testing.T) { | ||||
| @@ -68,11 +68,11 @@ func TestCreateOrStopIssueStopwatch(t *testing.T) { | ||||
| 	issue2, err := issues_model.GetIssueByID(db.DefaultContext, 2) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	assert.NoError(t, issues_model.CreateOrStopIssueStopwatch(org3, issue1)) | ||||
| 	assert.NoError(t, issues_model.CreateOrStopIssueStopwatch(db.DefaultContext, org3, issue1)) | ||||
| 	sw := unittest.AssertExistsAndLoadBean(t, &issues_model.Stopwatch{UserID: 3, IssueID: 1}) | ||||
| 	assert.LessOrEqual(t, sw.CreatedUnix, timeutil.TimeStampNow()) | ||||
|  | ||||
| 	assert.NoError(t, issues_model.CreateOrStopIssueStopwatch(user2, issue2)) | ||||
| 	assert.NoError(t, issues_model.CreateOrStopIssueStopwatch(db.DefaultContext, user2, issue2)) | ||||
| 	unittest.AssertNotExistsBean(t, &issues_model.Stopwatch{UserID: 2, IssueID: 2}) | ||||
| 	unittest.AssertExistsAndLoadBean(t, &issues_model.TrackedTime{UserID: 2, IssueID: 2}) | ||||
| } | ||||
|   | ||||
| @@ -4,6 +4,7 @@ | ||||
| package organization | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
|  | ||||
| @@ -19,7 +20,7 @@ import ( | ||||
| type MinimalOrg = Organization | ||||
|  | ||||
| // GetUserOrgsList returns all organizations the given user has access to | ||||
| func GetUserOrgsList(user *user_model.User) ([]*MinimalOrg, error) { | ||||
| func GetUserOrgsList(ctx context.Context, user *user_model.User) ([]*MinimalOrg, error) { | ||||
| 	schema, err := db.TableInfo(new(user_model.User)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @@ -42,7 +43,7 @@ func GetUserOrgsList(user *user_model.User) ([]*MinimalOrg, error) { | ||||
| 	groupByStr := groupByCols.String() | ||||
| 	groupByStr = groupByStr[0 : len(groupByStr)-1] | ||||
|  | ||||
| 	sess := db.GetEngine(db.DefaultContext) | ||||
| 	sess := db.GetEngine(ctx) | ||||
| 	sess = sess.Select(groupByStr+", count(distinct repo_id) as org_count"). | ||||
| 		Table("user"). | ||||
| 		Join("INNER", "team", "`team`.org_id = `user`.id"). | ||||
|   | ||||
| @@ -72,7 +72,7 @@ var delRepoArchiver = new(RepoArchiver) | ||||
|  | ||||
| // DeleteRepoArchiver delete archiver | ||||
| func DeleteRepoArchiver(ctx context.Context, archiver *RepoArchiver) error { | ||||
| 	_, err := db.GetEngine(db.DefaultContext).ID(archiver.ID).Delete(delRepoArchiver) | ||||
| 	_, err := db.GetEngine(ctx).ID(archiver.ID).Delete(delRepoArchiver) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| @@ -113,8 +113,8 @@ func UpdateRepoArchiverStatus(ctx context.Context, archiver *RepoArchiver) error | ||||
| } | ||||
|  | ||||
| // DeleteAllRepoArchives deletes all repo archives records | ||||
| func DeleteAllRepoArchives() error { | ||||
| 	_, err := db.GetEngine(db.DefaultContext).Where("1=1").Delete(new(RepoArchiver)) | ||||
| func DeleteAllRepoArchives(ctx context.Context) error { | ||||
| 	_, err := db.GetEngine(ctx).Where("1=1").Delete(new(RepoArchiver)) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| @@ -133,10 +133,10 @@ func (opts FindRepoArchiversOption) toConds() builder.Cond { | ||||
| } | ||||
|  | ||||
| // FindRepoArchives find repo archivers | ||||
| func FindRepoArchives(opts FindRepoArchiversOption) ([]*RepoArchiver, error) { | ||||
| func FindRepoArchives(ctx context.Context, opts FindRepoArchiversOption) ([]*RepoArchiver, error) { | ||||
| 	archivers := make([]*RepoArchiver, 0, opts.PageSize) | ||||
| 	start, limit := opts.GetSkipTake() | ||||
| 	err := db.GetEngine(db.DefaultContext).Where(opts.toConds()). | ||||
| 	err := db.GetEngine(ctx).Where(opts.toConds()). | ||||
| 		Asc("created_unix"). | ||||
| 		Limit(limit, start). | ||||
| 		Find(&archivers) | ||||
| @@ -144,7 +144,7 @@ func FindRepoArchives(opts FindRepoArchiversOption) ([]*RepoArchiver, error) { | ||||
| } | ||||
|  | ||||
| // SetArchiveRepoState sets if a repo is archived | ||||
| func SetArchiveRepoState(repo *Repository, isArchived bool) (err error) { | ||||
| func SetArchiveRepoState(ctx context.Context, repo *Repository, isArchived bool) (err error) { | ||||
| 	repo.IsArchived = isArchived | ||||
|  | ||||
| 	if isArchived { | ||||
| @@ -153,6 +153,6 @@ func SetArchiveRepoState(repo *Repository, isArchived bool) (err error) { | ||||
| 		repo.ArchivedUnix = timeutil.TimeStamp(0) | ||||
| 	} | ||||
|  | ||||
| 	_, err = db.GetEngine(db.DefaultContext).ID(repo.ID).Cols("is_archived", "archived_unix").NoAutoTime().Update(repo) | ||||
| 	_, err = db.GetEngine(ctx).ID(repo.ID).Cols("is_archived", "archived_unix").NoAutoTime().Update(repo) | ||||
| 	return err | ||||
| } | ||||
|   | ||||
| @@ -92,9 +92,9 @@ func SanitizeAndValidateTopics(topics []string) (validTopics, invalidTopics []st | ||||
| } | ||||
|  | ||||
| // GetTopicByName retrieves topic by name | ||||
| func GetTopicByName(name string) (*Topic, error) { | ||||
| func GetTopicByName(ctx context.Context, name string) (*Topic, error) { | ||||
| 	var topic Topic | ||||
| 	if has, err := db.GetEngine(db.DefaultContext).Where("name = ?", name).Get(&topic); err != nil { | ||||
| 	if has, err := db.GetEngine(ctx).Where("name = ?", name).Get(&topic); err != nil { | ||||
| 		return nil, err | ||||
| 	} else if !has { | ||||
| 		return nil, ErrTopicNotExist{name} | ||||
| @@ -192,8 +192,8 @@ func (opts *FindTopicOptions) toConds() builder.Cond { | ||||
| } | ||||
|  | ||||
| // FindTopics retrieves the topics via FindTopicOptions | ||||
| func FindTopics(opts *FindTopicOptions) ([]*Topic, int64, error) { | ||||
| 	sess := db.GetEngine(db.DefaultContext).Select("topic.*").Where(opts.toConds()) | ||||
| func FindTopics(ctx context.Context, opts *FindTopicOptions) ([]*Topic, int64, error) { | ||||
| 	sess := db.GetEngine(ctx).Select("topic.*").Where(opts.toConds()) | ||||
| 	orderBy := "topic.repo_count DESC" | ||||
| 	if opts.RepoID > 0 { | ||||
| 		sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") | ||||
| @@ -208,8 +208,8 @@ func FindTopics(opts *FindTopicOptions) ([]*Topic, int64, error) { | ||||
| } | ||||
|  | ||||
| // CountTopics counts the number of topics matching the FindTopicOptions | ||||
| func CountTopics(opts *FindTopicOptions) (int64, error) { | ||||
| 	sess := db.GetEngine(db.DefaultContext).Where(opts.toConds()) | ||||
| func CountTopics(ctx context.Context, opts *FindTopicOptions) (int64, error) { | ||||
| 	sess := db.GetEngine(ctx).Where(opts.toConds()) | ||||
| 	if opts.RepoID > 0 { | ||||
| 		sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") | ||||
| 	} | ||||
| @@ -231,8 +231,8 @@ func GetRepoTopicByName(ctx context.Context, repoID int64, topicName string) (*T | ||||
| } | ||||
|  | ||||
| // AddTopic adds a topic name to a repository (if it does not already have it) | ||||
| func AddTopic(repoID int64, topicName string) (*Topic, error) { | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| func AddTopic(ctx context.Context, repoID int64, topicName string) (*Topic, error) { | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -261,8 +261,8 @@ func AddTopic(repoID int64, topicName string) (*Topic, error) { | ||||
| } | ||||
|  | ||||
| // DeleteTopic removes a topic name from a repository (if it has it) | ||||
| func DeleteTopic(repoID int64, topicName string) (*Topic, error) { | ||||
| 	topic, err := GetRepoTopicByName(db.DefaultContext, repoID, topicName) | ||||
| func DeleteTopic(ctx context.Context, repoID int64, topicName string) (*Topic, error) { | ||||
| 	topic, err := GetRepoTopicByName(ctx, repoID, topicName) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -271,26 +271,26 @@ func DeleteTopic(repoID int64, topicName string) (*Topic, error) { | ||||
| 		return nil, nil | ||||
| 	} | ||||
|  | ||||
| 	err = removeTopicFromRepo(db.DefaultContext, repoID, topic) | ||||
| 	err = removeTopicFromRepo(ctx, repoID, topic) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = syncTopicsInRepository(db.GetEngine(db.DefaultContext), repoID) | ||||
| 	err = syncTopicsInRepository(db.GetEngine(ctx), repoID) | ||||
|  | ||||
| 	return topic, err | ||||
| } | ||||
|  | ||||
| // SaveTopics save topics to a repository | ||||
| func SaveTopics(repoID int64, topicNames ...string) error { | ||||
| 	topics, _, err := FindTopics(&FindTopicOptions{ | ||||
| func SaveTopics(ctx context.Context, repoID int64, topicNames ...string) error { | ||||
| 	topics, _, err := FindTopics(ctx, &FindTopicOptions{ | ||||
| 		RepoID: repoID, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -19,47 +19,47 @@ func TestAddTopic(t *testing.T) { | ||||
|  | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
|  | ||||
| 	topics, _, err := repo_model.FindTopics(&repo_model.FindTopicOptions{}) | ||||
| 	topics, _, err := repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{}) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Len(t, topics, totalNrOfTopics) | ||||
|  | ||||
| 	topics, total, err := repo_model.FindTopics(&repo_model.FindTopicOptions{ | ||||
| 	topics, total, err := repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ | ||||
| 		ListOptions: db.ListOptions{Page: 1, PageSize: 2}, | ||||
| 	}) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Len(t, topics, 2) | ||||
| 	assert.EqualValues(t, 6, total) | ||||
|  | ||||
| 	topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{ | ||||
| 	topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ | ||||
| 		RepoID: 1, | ||||
| 	}) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Len(t, topics, repo1NrOfTopics) | ||||
|  | ||||
| 	assert.NoError(t, repo_model.SaveTopics(2, "golang")) | ||||
| 	assert.NoError(t, repo_model.SaveTopics(db.DefaultContext, 2, "golang")) | ||||
| 	repo2NrOfTopics := 1 | ||||
| 	topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{}) | ||||
| 	topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{}) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Len(t, topics, totalNrOfTopics) | ||||
|  | ||||
| 	topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{ | ||||
| 	topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ | ||||
| 		RepoID: 2, | ||||
| 	}) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Len(t, topics, repo2NrOfTopics) | ||||
|  | ||||
| 	assert.NoError(t, repo_model.SaveTopics(2, "golang", "gitea")) | ||||
| 	assert.NoError(t, repo_model.SaveTopics(db.DefaultContext, 2, "golang", "gitea")) | ||||
| 	repo2NrOfTopics = 2 | ||||
| 	totalNrOfTopics++ | ||||
| 	topic, err := repo_model.GetTopicByName("gitea") | ||||
| 	topic, err := repo_model.GetTopicByName(db.DefaultContext, "gitea") | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.EqualValues(t, 1, topic.RepoCount) | ||||
|  | ||||
| 	topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{}) | ||||
| 	topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{}) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Len(t, topics, totalNrOfTopics) | ||||
|  | ||||
| 	topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{ | ||||
| 	topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ | ||||
| 		RepoID: 2, | ||||
| 	}) | ||||
| 	assert.NoError(t, err) | ||||
|   | ||||
| @@ -16,11 +16,11 @@ import ( | ||||
| ) | ||||
|  | ||||
| // UpdateRepositoryOwnerNames updates repository owner_names (this should only be used when the ownerName has changed case) | ||||
| func UpdateRepositoryOwnerNames(ownerID int64, ownerName string) error { | ||||
| func UpdateRepositoryOwnerNames(ctx context.Context, ownerID int64, ownerName string) error { | ||||
| 	if ownerID == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -36,8 +36,8 @@ func UpdateRepositoryOwnerNames(ownerID int64, ownerName string) error { | ||||
| } | ||||
|  | ||||
| // UpdateRepositoryUpdatedTime updates a repository's updated time | ||||
| func UpdateRepositoryUpdatedTime(repoID int64, updateTime time.Time) error { | ||||
| 	_, err := db.GetEngine(db.DefaultContext).Exec("UPDATE repository SET updated_unix = ? WHERE id = ?", updateTime.Unix(), repoID) | ||||
| func UpdateRepositoryUpdatedTime(ctx context.Context, repoID int64, updateTime time.Time) error { | ||||
| 	_, err := db.GetEngine(ctx).Exec("UPDATE repository SET updated_unix = ? WHERE id = ?", updateTime.Unix(), repoID) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| @@ -107,7 +107,7 @@ func (err ErrRepoFilesAlreadyExist) Unwrap() error { | ||||
| } | ||||
|  | ||||
| // CheckCreateRepository check if could created a repository | ||||
| func CheckCreateRepository(doer, u *user_model.User, name string, overwriteOrAdopt bool) error { | ||||
| func CheckCreateRepository(ctx context.Context, doer, u *user_model.User, name string, overwriteOrAdopt bool) error { | ||||
| 	if !doer.CanCreateRepo() { | ||||
| 		return ErrReachLimitOfRepo{u.MaxRepoCreation} | ||||
| 	} | ||||
| @@ -116,7 +116,7 @@ func CheckCreateRepository(doer, u *user_model.User, name string, overwriteOrAdo | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	has, err := IsRepositoryModelOrDirExist(db.DefaultContext, u, name) | ||||
| 	has, err := IsRepositoryModelOrDirExist(ctx, u, name) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("IsRepositoryExist: %w", err) | ||||
| 	} else if has { | ||||
| @@ -136,18 +136,18 @@ func CheckCreateRepository(doer, u *user_model.User, name string, overwriteOrAdo | ||||
| } | ||||
|  | ||||
| // ChangeRepositoryName changes all corresponding setting from old repository name to new one. | ||||
| func ChangeRepositoryName(doer *user_model.User, repo *Repository, newRepoName string) (err error) { | ||||
| func ChangeRepositoryName(ctx context.Context, doer *user_model.User, repo *Repository, newRepoName string) (err error) { | ||||
| 	oldRepoName := repo.Name | ||||
| 	newRepoName = strings.ToLower(newRepoName) | ||||
| 	if err = IsUsableRepoName(newRepoName); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if err := repo.LoadOwner(db.DefaultContext); err != nil { | ||||
| 	if err := repo.LoadOwner(ctx); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	has, err := IsRepositoryModelOrDirExist(db.DefaultContext, repo.Owner, newRepoName) | ||||
| 	has, err := IsRepositoryModelOrDirExist(ctx, repo.Owner, newRepoName) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("IsRepositoryExist: %w", err) | ||||
| 	} else if has { | ||||
| @@ -171,7 +171,7 @@ func ChangeRepositoryName(doer *user_model.User, repo *Repository, newRepoName s | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -79,8 +79,8 @@ func (r *RepoTransfer) LoadAttributes(ctx context.Context) error { | ||||
| // CanUserAcceptTransfer checks if the user has the rights to accept/decline a repo transfer. | ||||
| // For user, it checks if it's himself | ||||
| // For organizations, it checks if the user is able to create repos | ||||
| func (r *RepoTransfer) CanUserAcceptTransfer(u *user_model.User) bool { | ||||
| 	if err := r.LoadAttributes(db.DefaultContext); err != nil { | ||||
| func (r *RepoTransfer) CanUserAcceptTransfer(ctx context.Context, u *user_model.User) bool { | ||||
| 	if err := r.LoadAttributes(ctx); err != nil { | ||||
| 		log.Error("LoadAttributes: %v", err) | ||||
| 		return false | ||||
| 	} | ||||
| @@ -89,7 +89,7 @@ func (r *RepoTransfer) CanUserAcceptTransfer(u *user_model.User) bool { | ||||
| 		return r.RecipientID == u.ID | ||||
| 	} | ||||
|  | ||||
| 	allowed, err := organization.CanCreateOrgRepo(db.DefaultContext, r.RecipientID, u.ID) | ||||
| 	allowed, err := organization.CanCreateOrgRepo(ctx, r.RecipientID, u.ID) | ||||
| 	if err != nil { | ||||
| 		log.Error("CanCreateOrgRepo: %v", err) | ||||
| 		return false | ||||
| @@ -122,8 +122,8 @@ func deleteRepositoryTransfer(ctx context.Context, repoID int64) error { | ||||
|  | ||||
| // CancelRepositoryTransfer marks the repository as ready and remove pending transfer entry, | ||||
| // thus cancel the transfer process. | ||||
| func CancelRepositoryTransfer(repo *repo_model.Repository) error { | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| func CancelRepositoryTransfer(ctx context.Context, repo *repo_model.Repository) error { | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -199,7 +199,7 @@ func CreatePendingRepositoryTransfer(ctx context.Context, doer, newOwner *user_m | ||||
| } | ||||
|  | ||||
| // TransferOwnership transfers all corresponding repository items from old user to new one. | ||||
| func TransferOwnership(doer *user_model.User, newOwnerName string, repo *repo_model.Repository) (err error) { | ||||
| func TransferOwnership(ctx context.Context, doer *user_model.User, newOwnerName string, repo *repo_model.Repository) (err error) { | ||||
| 	repoRenamed := false | ||||
| 	wikiRenamed := false | ||||
| 	oldOwnerName := doer.Name | ||||
| @@ -234,7 +234,7 @@ func TransferOwnership(doer *user_model.User, newOwnerName string, repo *repo_mo | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -25,7 +25,7 @@ func TestRepositoryTransfer(t *testing.T) { | ||||
| 	assert.NotNil(t, transfer) | ||||
|  | ||||
| 	// Cancel transfer | ||||
| 	assert.NoError(t, CancelRepositoryTransfer(repo)) | ||||
| 	assert.NoError(t, CancelRepositoryTransfer(db.DefaultContext, repo)) | ||||
|  | ||||
| 	transfer, err = GetPendingRepositoryTransfer(db.DefaultContext, repo) | ||||
| 	assert.Error(t, err) | ||||
| @@ -53,5 +53,5 @@ func TestRepositoryTransfer(t *testing.T) { | ||||
| 	assert.Error(t, err) | ||||
|  | ||||
| 	// Cancel transfer | ||||
| 	assert.NoError(t, CancelRepositoryTransfer(repo)) | ||||
| 	assert.NoError(t, CancelRepositoryTransfer(db.DefaultContext, repo)) | ||||
| } | ||||
|   | ||||
| @@ -4,6 +4,8 @@ | ||||
| package user | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
|  | ||||
| 	"code.gitea.io/gitea/models/db" | ||||
| 	"code.gitea.io/gitea/modules/timeutil" | ||||
| ) | ||||
| @@ -21,18 +23,18 @@ func init() { | ||||
| } | ||||
|  | ||||
| // IsFollowing returns true if user is following followID. | ||||
| func IsFollowing(userID, followID int64) bool { | ||||
| 	has, _ := db.GetEngine(db.DefaultContext).Get(&Follow{UserID: userID, FollowID: followID}) | ||||
| func IsFollowing(ctx context.Context, userID, followID int64) bool { | ||||
| 	has, _ := db.GetEngine(ctx).Get(&Follow{UserID: userID, FollowID: followID}) | ||||
| 	return has | ||||
| } | ||||
|  | ||||
| // FollowUser marks someone be another's follower. | ||||
| func FollowUser(userID, followID int64) (err error) { | ||||
| 	if userID == followID || IsFollowing(userID, followID) { | ||||
| func FollowUser(ctx context.Context, userID, followID int64) (err error) { | ||||
| 	if userID == followID || IsFollowing(ctx, userID, followID) { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -53,12 +55,12 @@ func FollowUser(userID, followID int64) (err error) { | ||||
| } | ||||
|  | ||||
| // UnfollowUser unmarks someone as another's follower. | ||||
| func UnfollowUser(userID, followID int64) (err error) { | ||||
| 	if userID == followID || !IsFollowing(userID, followID) { | ||||
| func UnfollowUser(ctx context.Context, userID, followID int64) (err error) { | ||||
| 	if userID == followID || !IsFollowing(ctx, userID, followID) { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	ctx, committer, err := db.TxContext(db.DefaultContext) | ||||
| 	ctx, committer, err := db.TxContext(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -6,6 +6,7 @@ package user_test | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"code.gitea.io/gitea/models/db" | ||||
| 	"code.gitea.io/gitea/models/unittest" | ||||
| 	user_model "code.gitea.io/gitea/models/user" | ||||
|  | ||||
| @@ -14,9 +15,9 @@ import ( | ||||
|  | ||||
| func TestIsFollowing(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
| 	assert.True(t, user_model.IsFollowing(4, 2)) | ||||
| 	assert.False(t, user_model.IsFollowing(2, 4)) | ||||
| 	assert.False(t, user_model.IsFollowing(5, unittest.NonexistentID)) | ||||
| 	assert.False(t, user_model.IsFollowing(unittest.NonexistentID, 5)) | ||||
| 	assert.False(t, user_model.IsFollowing(unittest.NonexistentID, unittest.NonexistentID)) | ||||
| 	assert.True(t, user_model.IsFollowing(db.DefaultContext, 4, 2)) | ||||
| 	assert.False(t, user_model.IsFollowing(db.DefaultContext, 2, 4)) | ||||
| 	assert.False(t, user_model.IsFollowing(db.DefaultContext, 5, unittest.NonexistentID)) | ||||
| 	assert.False(t, user_model.IsFollowing(db.DefaultContext, unittest.NonexistentID, 5)) | ||||
| 	assert.False(t, user_model.IsFollowing(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID)) | ||||
| } | ||||
|   | ||||
| @@ -1246,7 +1246,7 @@ func IsUserVisibleToViewer(ctx context.Context, u, viewer *User) bool { | ||||
| 		} | ||||
|  | ||||
| 		// If they follow - they see each over | ||||
| 		follower := IsFollowing(u.ID, viewer.ID) | ||||
| 		follower := IsFollowing(ctx, u.ID, viewer.ID) | ||||
| 		if follower { | ||||
| 			return true | ||||
| 		} | ||||
|   | ||||
| @@ -449,13 +449,13 @@ func TestFollowUser(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
|  | ||||
| 	testSuccess := func(followerID, followedID int64) { | ||||
| 		assert.NoError(t, user_model.FollowUser(followerID, followedID)) | ||||
| 		assert.NoError(t, user_model.FollowUser(db.DefaultContext, followerID, followedID)) | ||||
| 		unittest.AssertExistsAndLoadBean(t, &user_model.Follow{UserID: followerID, FollowID: followedID}) | ||||
| 	} | ||||
| 	testSuccess(4, 2) | ||||
| 	testSuccess(5, 2) | ||||
|  | ||||
| 	assert.NoError(t, user_model.FollowUser(2, 2)) | ||||
| 	assert.NoError(t, user_model.FollowUser(db.DefaultContext, 2, 2)) | ||||
|  | ||||
| 	unittest.CheckConsistencyFor(t, &user_model.User{}) | ||||
| } | ||||
| @@ -464,7 +464,7 @@ func TestUnfollowUser(t *testing.T) { | ||||
| 	assert.NoError(t, unittest.PrepareTestDatabase()) | ||||
|  | ||||
| 	testSuccess := func(followerID, followedID int64) { | ||||
| 		assert.NoError(t, user_model.UnfollowUser(followerID, followedID)) | ||||
| 		assert.NoError(t, user_model.UnfollowUser(db.DefaultContext, followerID, followedID)) | ||||
| 		unittest.AssertNotExistsBean(t, &user_model.Follow{UserID: followerID, FollowID: followedID}) | ||||
| 	} | ||||
| 	testSuccess(4, 2) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user