mirror of
https://github.com/go-gitea/gitea
synced 2025-07-23 18:58:38 +00:00
68
modules/session/mem.go
Normal file
68
modules/session/mem.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
// Copyright 2025 The Gitea Authors. All rights reserved.
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/gob"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"gitea.com/go-chi/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockMemRawStore struct {
|
||||||
|
s *session.MemStore
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ session.RawStore = (*mockMemRawStore)(nil)
|
||||||
|
|
||||||
|
func (m *mockMemRawStore) Set(k, v any) error {
|
||||||
|
// We need to use gob to encode the value, to make it have the same behavior as other stores and catch abuses.
|
||||||
|
// Because gob needs to "Register" the type before it can encode it, and it's unable to decode a struct to "any" so use a map to help to decode the value.
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := gob.NewEncoder(&buf).Encode(map[string]any{"v": v}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return m.s.Set(k, buf.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMemRawStore) Get(k any) (ret any) {
|
||||||
|
v, ok := m.s.Get(k).([]byte)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var w map[string]any
|
||||||
|
_ = gob.NewDecoder(bytes.NewBuffer(v)).Decode(&w)
|
||||||
|
return w["v"]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMemRawStore) Delete(k any) error {
|
||||||
|
return m.s.Delete(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMemRawStore) ID() string {
|
||||||
|
return m.s.ID()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMemRawStore) Release() error {
|
||||||
|
return m.s.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMemRawStore) Flush() error {
|
||||||
|
return m.s.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockMemStore struct {
|
||||||
|
*mockMemRawStore
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Store = (*mockMemStore)(nil)
|
||||||
|
|
||||||
|
func (m mockMemStore) Destroy(writer http.ResponseWriter, request *http.Request) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockMemStore(sid string) Store {
|
||||||
|
return &mockMemStore{&mockMemRawStore{session.NewMemStore(sid)}}
|
||||||
|
}
|
@@ -1,26 +0,0 @@
|
|||||||
// Copyright 2024 The Gitea Authors. All rights reserved.
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
package session
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"gitea.com/go-chi/session"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockStore struct {
|
|
||||||
*session.MemStore
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockStore) Destroy(writer http.ResponseWriter, request *http.Request) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockStoreContextKeyStruct struct{}
|
|
||||||
|
|
||||||
var MockStoreContextKey = mockStoreContextKeyStruct{}
|
|
||||||
|
|
||||||
func NewMockStore(sid string) *MockStore {
|
|
||||||
return &MockStore{session.NewMemStore(sid)}
|
|
||||||
}
|
|
@@ -11,25 +11,25 @@ import (
|
|||||||
"gitea.com/go-chi/session"
|
"gitea.com/go-chi/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Store represents a session store
|
type RawStore = session.RawStore
|
||||||
|
|
||||||
type Store interface {
|
type Store interface {
|
||||||
Get(any) any
|
RawStore
|
||||||
Set(any, any) error
|
|
||||||
Delete(any) error
|
|
||||||
ID() string
|
|
||||||
Release() error
|
|
||||||
Flush() error
|
|
||||||
Destroy(http.ResponseWriter, *http.Request) error
|
Destroy(http.ResponseWriter, *http.Request) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockStoreContextKeyStruct struct{}
|
||||||
|
|
||||||
|
var MockStoreContextKey = mockStoreContextKeyStruct{}
|
||||||
|
|
||||||
// RegenerateSession regenerates the underlying session and returns the new store
|
// RegenerateSession regenerates the underlying session and returns the new store
|
||||||
func RegenerateSession(resp http.ResponseWriter, req *http.Request) (Store, error) {
|
func RegenerateSession(resp http.ResponseWriter, req *http.Request) (Store, error) {
|
||||||
for _, f := range BeforeRegenerateSession {
|
for _, f := range BeforeRegenerateSession {
|
||||||
f(resp, req)
|
f(resp, req)
|
||||||
}
|
}
|
||||||
if setting.IsInTesting {
|
if setting.IsInTesting {
|
||||||
if store, ok := req.Context().Value(MockStoreContextKey).(*MockStore); ok {
|
if store := req.Context().Value(MockStoreContextKey); store != nil {
|
||||||
return store, nil
|
return store.(Store), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return session.RegenerateSession(resp, req)
|
return session.RegenerateSession(resp, req)
|
||||||
@@ -37,8 +37,8 @@ func RegenerateSession(resp http.ResponseWriter, req *http.Request) (Store, erro
|
|||||||
|
|
||||||
func GetContextSession(req *http.Request) Store {
|
func GetContextSession(req *http.Request) Store {
|
||||||
if setting.IsInTesting {
|
if setting.IsInTesting {
|
||||||
if store, ok := req.Context().Value(MockStoreContextKey).(*MockStore); ok {
|
if store := req.Context().Value(MockStoreContextKey); store != nil {
|
||||||
return store
|
return store.(Store)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return session.GetSession(req)
|
return session.GetSession(req)
|
||||||
|
@@ -22,8 +22,8 @@ type VirtualSessionProvider struct {
|
|||||||
provider session.Provider
|
provider session.Provider
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes the cookie session provider with given root path.
|
// Init initializes the cookie session provider with the given config.
|
||||||
func (o *VirtualSessionProvider) Init(gclifetime int64, config string) error {
|
func (o *VirtualSessionProvider) Init(gcLifetime int64, config string) error {
|
||||||
var opts session.Options
|
var opts session.Options
|
||||||
if err := json.Unmarshal([]byte(config), &opts); err != nil {
|
if err := json.Unmarshal([]byte(config), &opts); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -52,7 +52,7 @@ func (o *VirtualSessionProvider) Init(gclifetime int64, config string) error {
|
|||||||
default:
|
default:
|
||||||
return fmt.Errorf("VirtualSessionProvider: Unknown Provider: %s", opts.Provider)
|
return fmt.Errorf("VirtualSessionProvider: Unknown Provider: %s", opts.Provider)
|
||||||
}
|
}
|
||||||
return o.provider.Init(gclifetime, opts.ProviderConfig)
|
return o.provider.Init(gcLifetime, opts.ProviderConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read returns raw session store by session ID.
|
// Read returns raw session store by session ID.
|
||||||
|
@@ -565,7 +565,7 @@ func createUserInContext(ctx *context.Context, tpl templates.TplName, form any,
|
|||||||
oauth2LinkAccount(ctx, user, possibleLinkAccountData, true)
|
oauth2LinkAccount(ctx, user, possibleLinkAccountData, true)
|
||||||
return false // user is already created here, all redirects are handled
|
return false // user is already created here, all redirects are handled
|
||||||
case setting.OAuth2AccountLinkingLogin:
|
case setting.OAuth2AccountLinkingLogin:
|
||||||
showLinkingLogin(ctx, &possibleLinkAccountData.AuthSource, possibleLinkAccountData.GothUser)
|
showLinkingLogin(ctx, possibleLinkAccountData.AuthSourceID, possibleLinkAccountData.GothUser)
|
||||||
return false // user will be created only after linking login
|
return false // user will be created only after linking login
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -633,7 +633,7 @@ func handleUserCreated(ctx *context.Context, u *user_model.User, possibleLinkAcc
|
|||||||
|
|
||||||
// update external user information
|
// update external user information
|
||||||
if possibleLinkAccountData != nil {
|
if possibleLinkAccountData != nil {
|
||||||
if err := externalaccount.EnsureLinkExternalToUser(ctx, possibleLinkAccountData.AuthSource.ID, u, possibleLinkAccountData.GothUser); err != nil {
|
if err := externalaccount.EnsureLinkExternalToUser(ctx, possibleLinkAccountData.AuthSourceID, u, possibleLinkAccountData.GothUser); err != nil {
|
||||||
log.Error("EnsureLinkExternalToUser failed: %v", err)
|
log.Error("EnsureLinkExternalToUser failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -64,13 +64,14 @@ func TestUserLogin(t *testing.T) {
|
|||||||
func TestSignUpOAuth2Login(t *testing.T) {
|
func TestSignUpOAuth2Login(t *testing.T) {
|
||||||
defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, true)()
|
defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, true)()
|
||||||
|
|
||||||
|
_ = oauth2.Init(t.Context())
|
||||||
addOAuth2Source(t, "dummy-auth-source", oauth2.Source{})
|
addOAuth2Source(t, "dummy-auth-source", oauth2.Source{})
|
||||||
|
|
||||||
t.Run("OAuth2MissingField", func(t *testing.T) {
|
t.Run("OAuth2MissingField", func(t *testing.T) {
|
||||||
defer test.MockVariableValue(&gothic.CompleteUserAuth, func(res http.ResponseWriter, req *http.Request) (goth.User, error) {
|
defer test.MockVariableValue(&gothic.CompleteUserAuth, func(res http.ResponseWriter, req *http.Request) (goth.User, error) {
|
||||||
return goth.User{Provider: "dummy-auth-source", UserID: "dummy-user"}, nil
|
return goth.User{Provider: "dummy-auth-source", UserID: "dummy-user"}, nil
|
||||||
})()
|
})()
|
||||||
mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockStore("dummy-sid")}
|
mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockMemStore("dummy-sid")}
|
||||||
ctx, resp := contexttest.MockContext(t, "/user/oauth2/dummy-auth-source/callback?code=dummy-code", mockOpt)
|
ctx, resp := contexttest.MockContext(t, "/user/oauth2/dummy-auth-source/callback?code=dummy-code", mockOpt)
|
||||||
ctx.SetPathParam("provider", "dummy-auth-source")
|
ctx.SetPathParam("provider", "dummy-auth-source")
|
||||||
SignInOAuthCallback(ctx)
|
SignInOAuthCallback(ctx)
|
||||||
@@ -84,7 +85,7 @@ func TestSignUpOAuth2Login(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("OAuth2CallbackError", func(t *testing.T) {
|
t.Run("OAuth2CallbackError", func(t *testing.T) {
|
||||||
mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockStore("dummy-sid")}
|
mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockMemStore("dummy-sid")}
|
||||||
ctx, resp := contexttest.MockContext(t, "/user/oauth2/dummy-auth-source/callback", mockOpt)
|
ctx, resp := contexttest.MockContext(t, "/user/oauth2/dummy-auth-source/callback", mockOpt)
|
||||||
ctx.SetPathParam("provider", "dummy-auth-source")
|
ctx.SetPathParam("provider", "dummy-auth-source")
|
||||||
SignInOAuthCallback(ctx)
|
SignInOAuthCallback(ctx)
|
||||||
|
@@ -170,7 +170,7 @@ func LinkAccountPostSignIn(ctx *context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func oauth2LinkAccount(ctx *context.Context, u *user_model.User, linkAccountData *LinkAccountData, remember bool) {
|
func oauth2LinkAccount(ctx *context.Context, u *user_model.User, linkAccountData *LinkAccountData, remember bool) {
|
||||||
oauth2SignInSync(ctx, &linkAccountData.AuthSource, u, linkAccountData.GothUser)
|
oauth2SignInSync(ctx, linkAccountData.AuthSourceID, u, linkAccountData.GothUser)
|
||||||
if ctx.Written() {
|
if ctx.Written() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -185,7 +185,7 @@ func oauth2LinkAccount(ctx *context.Context, u *user_model.User, linkAccountData
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSource.ID, u, linkAccountData.GothUser)
|
err = externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSourceID, u, linkAccountData.GothUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.ServerError("UserLinkAccount", err)
|
ctx.ServerError("UserLinkAccount", err)
|
||||||
return
|
return
|
||||||
@@ -295,7 +295,7 @@ func LinkAccountPostRegister(ctx *context.Context) {
|
|||||||
Email: form.Email,
|
Email: form.Email,
|
||||||
Passwd: form.Password,
|
Passwd: form.Password,
|
||||||
LoginType: auth.OAuth2,
|
LoginType: auth.OAuth2,
|
||||||
LoginSource: linkAccountData.AuthSource.ID,
|
LoginSource: linkAccountData.AuthSourceID,
|
||||||
LoginName: linkAccountData.GothUser.UserID,
|
LoginName: linkAccountData.GothUser.UserID,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -304,7 +304,12 @@ func LinkAccountPostRegister(ctx *context.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
source := linkAccountData.AuthSource.Cfg.(*oauth2.Source)
|
authSource, err := auth.GetSourceByID(ctx, linkAccountData.AuthSourceID)
|
||||||
|
if err != nil {
|
||||||
|
ctx.ServerError("GetSourceByID", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
source := authSource.Cfg.(*oauth2.Source)
|
||||||
if err := syncGroupsToTeams(ctx, source, &linkAccountData.GothUser, u); err != nil {
|
if err := syncGroupsToTeams(ctx, source, &linkAccountData.GothUser, u); err != nil {
|
||||||
ctx.ServerError("SyncGroupsToTeams", err)
|
ctx.ServerError("SyncGroupsToTeams", err)
|
||||||
return
|
return
|
||||||
@@ -318,5 +323,5 @@ func linkAccountFromContext(ctx *context.Context, user *user_model.User) error {
|
|||||||
if linkAccountData == nil {
|
if linkAccountData == nil {
|
||||||
return errors.New("not in LinkAccount session")
|
return errors.New("not in LinkAccount session")
|
||||||
}
|
}
|
||||||
return externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSource.ID, user, linkAccountData.GothUser)
|
return externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSourceID, user, linkAccountData.GothUser)
|
||||||
}
|
}
|
||||||
|
@@ -4,6 +4,7 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/gob"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html"
|
"html"
|
||||||
@@ -171,7 +172,7 @@ func SignInOAuthCallback(ctx *context.Context) {
|
|||||||
gothUser.RawData = make(map[string]any)
|
gothUser.RawData = make(map[string]any)
|
||||||
}
|
}
|
||||||
gothUser.RawData["__giteaAutoRegMissingFields"] = missingFields
|
gothUser.RawData["__giteaAutoRegMissingFields"] = missingFields
|
||||||
showLinkingLogin(ctx, authSource, gothUser)
|
showLinkingLogin(ctx, authSource.ID, gothUser)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
u = &user_model.User{
|
u = &user_model.User{
|
||||||
@@ -192,7 +193,7 @@ func SignInOAuthCallback(ctx *context.Context) {
|
|||||||
u.IsAdmin = isAdmin.ValueOrDefault(user_service.UpdateOptionField[bool]{FieldValue: false}).FieldValue
|
u.IsAdmin = isAdmin.ValueOrDefault(user_service.UpdateOptionField[bool]{FieldValue: false}).FieldValue
|
||||||
u.IsRestricted = isRestricted.ValueOrDefault(setting.Service.DefaultUserIsRestricted)
|
u.IsRestricted = isRestricted.ValueOrDefault(setting.Service.DefaultUserIsRestricted)
|
||||||
|
|
||||||
linkAccountData := &LinkAccountData{*authSource, gothUser}
|
linkAccountData := &LinkAccountData{authSource.ID, gothUser}
|
||||||
if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingDisabled {
|
if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingDisabled {
|
||||||
linkAccountData = nil
|
linkAccountData = nil
|
||||||
}
|
}
|
||||||
@@ -207,7 +208,7 @@ func SignInOAuthCallback(ctx *context.Context) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// no existing user is found, request attach or new account
|
// no existing user is found, request attach or new account
|
||||||
showLinkingLogin(ctx, authSource, gothUser)
|
showLinkingLogin(ctx, authSource.ID, gothUser)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -272,11 +273,12 @@ func getUserAdminAndRestrictedFromGroupClaims(source *oauth2.Source, gothUser *g
|
|||||||
}
|
}
|
||||||
|
|
||||||
type LinkAccountData struct {
|
type LinkAccountData struct {
|
||||||
AuthSource auth.Source
|
AuthSourceID int64
|
||||||
GothUser goth.User
|
GothUser goth.User
|
||||||
}
|
}
|
||||||
|
|
||||||
func oauth2GetLinkAccountData(ctx *context.Context) *LinkAccountData {
|
func oauth2GetLinkAccountData(ctx *context.Context) *LinkAccountData {
|
||||||
|
gob.Register(LinkAccountData{})
|
||||||
v, ok := ctx.Session.Get("linkAccountData").(LinkAccountData)
|
v, ok := ctx.Session.Get("linkAccountData").(LinkAccountData)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
@@ -284,11 +286,16 @@ func oauth2GetLinkAccountData(ctx *context.Context) *LinkAccountData {
|
|||||||
return &v
|
return &v
|
||||||
}
|
}
|
||||||
|
|
||||||
func showLinkingLogin(ctx *context.Context, authSource *auth.Source, gothUser goth.User) {
|
func Oauth2SetLinkAccountData(ctx *context.Context, linkAccountData LinkAccountData) error {
|
||||||
if err := updateSession(ctx, nil, map[string]any{
|
gob.Register(LinkAccountData{})
|
||||||
"linkAccountData": LinkAccountData{*authSource, gothUser},
|
return updateSession(ctx, nil, map[string]any{
|
||||||
}); err != nil {
|
"linkAccountData": linkAccountData,
|
||||||
ctx.ServerError("updateSession", err)
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func showLinkingLogin(ctx *context.Context, authSourceID int64, gothUser goth.User) {
|
||||||
|
if err := Oauth2SetLinkAccountData(ctx, LinkAccountData{authSourceID, gothUser}); err != nil {
|
||||||
|
ctx.ServerError("Oauth2SetLinkAccountData", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx.Redirect(setting.AppSubURL + "/user/link_account")
|
ctx.Redirect(setting.AppSubURL + "/user/link_account")
|
||||||
@@ -313,7 +320,7 @@ func oauth2UpdateAvatarIfNeed(ctx *context.Context, url string, u *user_model.Us
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handleOAuth2SignIn(ctx *context.Context, authSource *auth.Source, u *user_model.User, gothUser goth.User) {
|
func handleOAuth2SignIn(ctx *context.Context, authSource *auth.Source, u *user_model.User, gothUser goth.User) {
|
||||||
oauth2SignInSync(ctx, authSource, u, gothUser)
|
oauth2SignInSync(ctx, authSource.ID, u, gothUser)
|
||||||
if ctx.Written() {
|
if ctx.Written() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@@ -18,9 +18,14 @@ import (
|
|||||||
"github.com/markbates/goth"
|
"github.com/markbates/goth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func oauth2SignInSync(ctx *context.Context, authSource *auth.Source, u *user_model.User, gothUser goth.User) {
|
func oauth2SignInSync(ctx *context.Context, authSourceID int64, u *user_model.User, gothUser goth.User) {
|
||||||
oauth2UpdateAvatarIfNeed(ctx, gothUser.AvatarURL, u)
|
oauth2UpdateAvatarIfNeed(ctx, gothUser.AvatarURL, u)
|
||||||
|
|
||||||
|
authSource, err := auth.GetSourceByID(ctx, authSourceID)
|
||||||
|
if err != nil {
|
||||||
|
ctx.ServerError("GetSourceByID", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
oauth2Source, _ := authSource.Cfg.(*oauth2.Source)
|
oauth2Source, _ := authSource.Cfg.(*oauth2.Source)
|
||||||
if !authSource.IsOAuth2() || oauth2Source == nil {
|
if !authSource.IsOAuth2() || oauth2Source == nil {
|
||||||
ctx.ServerError("oauth2SignInSync", fmt.Errorf("source %s is not an OAuth2 source", gothUser.Provider))
|
ctx.ServerError("oauth2SignInSync", fmt.Errorf("source %s is not an OAuth2 source", gothUser.Provider))
|
||||||
@@ -45,7 +50,7 @@ func oauth2SignInSync(ctx *context.Context, authSource *auth.Source, u *user_mod
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := oauth2UpdateSSHPubIfNeed(ctx, authSource, &gothUser, u)
|
err = oauth2UpdateSSHPubIfNeed(ctx, authSource, &gothUser, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Unable to sync OAuth2 SSH public key %s: %v", gothUser.Provider, err)
|
log.Error("Unable to sync OAuth2 SSH public key %s: %v", gothUser.Provider, err)
|
||||||
}
|
}
|
||||||
|
@@ -11,7 +11,6 @@ import (
|
|||||||
"code.gitea.io/gitea/modules/log"
|
"code.gitea.io/gitea/modules/log"
|
||||||
session_module "code.gitea.io/gitea/modules/session"
|
session_module "code.gitea.io/gitea/modules/session"
|
||||||
|
|
||||||
chiSession "gitea.com/go-chi/session"
|
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,11 +34,11 @@ func (st *SessionsStore) New(r *http.Request, name string) (*sessions.Session, e
|
|||||||
|
|
||||||
// getOrNew gets the session from the chi-session if it exists. Override permits the overriding of an unexpected object.
|
// getOrNew gets the session from the chi-session if it exists. Override permits the overriding of an unexpected object.
|
||||||
func (st *SessionsStore) getOrNew(r *http.Request, name string, override bool) (*sessions.Session, error) {
|
func (st *SessionsStore) getOrNew(r *http.Request, name string, override bool) (*sessions.Session, error) {
|
||||||
chiStore := chiSession.GetSession(r)
|
store := session_module.GetContextSession(r)
|
||||||
|
|
||||||
session := sessions.NewSession(st, name)
|
session := sessions.NewSession(st, name)
|
||||||
|
|
||||||
rawData := chiStore.Get(name)
|
rawData := store.Get(name)
|
||||||
if rawData != nil {
|
if rawData != nil {
|
||||||
oldSession, ok := rawData.(*sessions.Session)
|
oldSession, ok := rawData.(*sessions.Session)
|
||||||
if ok {
|
if ok {
|
||||||
@@ -56,21 +55,21 @@ func (st *SessionsStore) getOrNew(r *http.Request, name string, override bool) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
session.IsNew = override
|
session.IsNew = override
|
||||||
session.ID = chiStore.ID() // Simply copy the session id from the chi store
|
session.ID = store.ID() // Simply copy the session id from the chi store
|
||||||
|
|
||||||
return session, chiStore.Set(name, session)
|
return session, store.Set(name, session)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save should persist session to the underlying store implementation.
|
// Save should persist session to the underlying store implementation.
|
||||||
func (st *SessionsStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
|
func (st *SessionsStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
|
||||||
chiStore := chiSession.GetSession(r)
|
store := session_module.GetContextSession(r)
|
||||||
|
|
||||||
if session.IsNew {
|
if session.IsNew {
|
||||||
_, _ = session_module.RegenerateSession(w, r)
|
_, _ = session_module.RegenerateSession(w, r)
|
||||||
session.IsNew = false
|
session.IsNew = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := chiStore.Set(session.Name(), session); err != nil {
|
if err := store.Set(session.Name(), session); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,7 +82,7 @@ func (st *SessionsStore) Save(r *http.Request, w http.ResponseWriter, session *s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return chiStore.Release()
|
return store.Release()
|
||||||
}
|
}
|
||||||
|
|
||||||
type sizeWriter struct {
|
type sizeWriter struct {
|
||||||
|
@@ -49,7 +49,7 @@ func mockRequest(t *testing.T, reqPath string) *http.Request {
|
|||||||
|
|
||||||
type MockContextOption struct {
|
type MockContextOption struct {
|
||||||
Render context.Render
|
Render context.Render
|
||||||
SessionStore *session.MockStore
|
SessionStore session.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockContext mock context for unit tests
|
// MockContext mock context for unit tests
|
||||||
|
@@ -107,7 +107,7 @@ func TestEnablePasswordSignInFormAndEnablePasskeyAuth(t *testing.T) {
|
|||||||
mockLinkAccount := func(ctx *context.Context) {
|
mockLinkAccount := func(ctx *context.Context) {
|
||||||
authSource := auth_model.Source{ID: 1}
|
authSource := auth_model.Source{ID: 1}
|
||||||
gothUser := goth.User{Email: "invalid-email", Name: "."}
|
gothUser := goth.User{Email: "invalid-email", Name: "."}
|
||||||
_ = ctx.Session.Set("linkAccountData", auth.LinkAccountData{AuthSource: authSource, GothUser: gothUser})
|
_ = auth.Oauth2SetLinkAccountData(ctx, auth.LinkAccountData{AuthSourceID: authSource.ID, GothUser: gothUser})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("EnablePasswordSignInForm=false", func(t *testing.T) {
|
t.Run("EnablePasswordSignInForm=false", func(t *testing.T) {
|
||||||
|
Reference in New Issue
Block a user