diff --git a/routers/web/user/oauth.go b/routers/web/user/oauth.go index e29826630a..cec6a92bbe 100644 --- a/routers/web/user/oauth.go +++ b/routers/web/user/oauth.go @@ -115,7 +115,7 @@ type AccessTokenResponse struct { IDToken string `json:"id_token,omitempty"` } -func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) { +func newAccessTokenResponse(grant *models.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) { if setting.OAuth2.InvalidateRefreshTokens { if err := grant.IncreaseCounter(); err != nil { return nil, &AccessTokenError{ @@ -133,7 +133,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign ExpiresAt: expirationDate.AsTime().Unix(), }, } - signedAccessToken, err := accessToken.SignToken() + signedAccessToken, err := accessToken.SignToken(serverKey) if err != nil { return nil, &AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidRequest, @@ -151,7 +151,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign ExpiresAt: refreshExpirationDate, }, } - signedRefreshToken, err := refreshToken.SignToken() + signedRefreshToken, err := refreshToken.SignToken(serverKey) if err != nil { return nil, &AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidRequest, @@ -207,7 +207,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign idToken.EmailVerified = user.IsActive } - signedIDToken, err = idToken.SignToken(signingKey) + signedIDToken, err = idToken.SignToken(clientKey) if err != nil { return nil, &AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidRequest, @@ -265,7 +265,7 @@ func IntrospectOAuth(ctx *context.Context) { } form := web.GetForm(ctx).(*forms.IntrospectTokenForm) - token, err := oauth2.ParseToken(form.Token) + token, err := oauth2.ParseToken(form.Token, oauth2.DefaultSigningKey) if err == nil { if token.Valid() == nil { grant, err := models.GetOAuth2GrantByID(token.GrantID) @@ -544,9 +544,11 @@ func AccessTokenOAuth(ctx *context.Context) { } } - signingKey := oauth2.DefaultSigningKey - if signingKey.IsSymmetric() { - clientKey, err := oauth2.CreateJWTSigningKey(signingKey.SigningMethod().Alg(), []byte(form.ClientSecret)) + serverKey := oauth2.DefaultSigningKey + clientKey := serverKey + if serverKey.IsSymmetric() { + var err error + clientKey, err = oauth2.CreateJWTSigningKey(serverKey.SigningMethod().Alg(), []byte(form.ClientSecret)) if err != nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidRequest, @@ -554,14 +556,13 @@ func AccessTokenOAuth(ctx *context.Context) { }) return } - signingKey = clientKey } switch form.GrantType { case "refresh_token": - handleRefreshToken(ctx, form, signingKey) + handleRefreshToken(ctx, form, serverKey, clientKey) case "authorization_code": - handleAuthorizationCode(ctx, form, signingKey) + handleAuthorizationCode(ctx, form, serverKey, clientKey) default: handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeUnsupportedGrantType, @@ -570,8 +571,8 @@ func AccessTokenOAuth(ctx *context.Context) { } } -func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signingKey oauth2.JWTSigningKey) { - token, err := oauth2.ParseToken(form.RefreshToken) +func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) { + token, err := oauth2.ParseToken(form.RefreshToken, serverKey) if err != nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeUnauthorizedClient, @@ -598,7 +599,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signin log.Warn("A client tried to use a refresh token for grant_id = %d was used twice!", grant.ID) return } - accessToken, tokenErr := newAccessTokenResponse(grant, signingKey) + accessToken, tokenErr := newAccessTokenResponse(grant, serverKey, clientKey) if tokenErr != nil { handleAccessTokenError(ctx, *tokenErr) return @@ -606,7 +607,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signin ctx.JSON(http.StatusOK, accessToken) } -func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, signingKey oauth2.JWTSigningKey) { +func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) { app, err := models.GetOAuth2ApplicationByClientID(form.ClientID) if err != nil { handleAccessTokenError(ctx, AccessTokenError{ @@ -660,7 +661,7 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s ErrorDescription: "cannot proceed your request", }) } - resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, signingKey) + resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, serverKey, clientKey) if tokenErr != nil { handleAccessTokenError(ctx, *tokenErr) return diff --git a/routers/web/user/oauth_test.go b/routers/web/user/oauth_test.go index c2f9ec87b5..40116d3c12 100644 --- a/routers/web/user/oauth_test.go +++ b/routers/web/user/oauth_test.go @@ -18,9 +18,8 @@ func createAndParseToken(t *testing.T, grant *models.OAuth2Grant) *oauth2.OIDCTo signingKey, err := oauth2.CreateJWTSigningKey("HS256", make([]byte, 32)) assert.NoError(t, err) assert.NotNil(t, signingKey) - oauth2.DefaultSigningKey = signingKey - response, terr := newAccessTokenResponse(grant, signingKey) + response, terr := newAccessTokenResponse(grant, signingKey, signingKey) assert.Nil(t, terr) assert.NotNil(t, response) diff --git a/services/auth/oauth2.go b/services/auth/oauth2.go index f7f870dade..665e5232cc 100644 --- a/services/auth/oauth2.go +++ b/services/auth/oauth2.go @@ -29,9 +29,9 @@ func CheckOAuthAccessToken(accessToken string) int64 { if !strings.Contains(accessToken, ".") { return 0 } - token, err := oauth2.ParseToken(accessToken) + token, err := oauth2.ParseToken(accessToken, oauth2.DefaultSigningKey) if err != nil { - log.Trace("ParseOAuth2Token: %v", err) + log.Trace("oauth2.ParseToken: %v", err) return 0 } var grant *models.OAuth2Grant diff --git a/services/auth/source/oauth2/token.go b/services/auth/source/oauth2/token.go index 529e04577d..16d1220842 100644 --- a/services/auth/source/oauth2/token.go +++ b/services/auth/source/oauth2/token.go @@ -40,12 +40,12 @@ type Token struct { } // ParseToken parses a signed jwt string -func ParseToken(jwtToken string) (*Token, error) { +func ParseToken(jwtToken string, signingKey JWTSigningKey) (*Token, error) { parsedToken, err := jwt.ParseWithClaims(jwtToken, &Token{}, func(token *jwt.Token) (interface{}, error) { - if token.Method == nil || token.Method.Alg() != DefaultSigningKey.SigningMethod().Alg() { + if token.Method == nil || token.Method.Alg() != signingKey.SigningMethod().Alg() { return nil, fmt.Errorf("unexpected signing algo: %v", token.Header["alg"]) } - return DefaultSigningKey.VerifyKey(), nil + return signingKey.VerifyKey(), nil }) if err != nil { return nil, err @@ -59,11 +59,11 @@ func ParseToken(jwtToken string) (*Token, error) { } // SignToken signs the token with the JWT secret -func (token *Token) SignToken() (string, error) { +func (token *Token) SignToken(signingKey JWTSigningKey) (string, error) { token.IssuedAt = time.Now().Unix() - jwtToken := jwt.NewWithClaims(DefaultSigningKey.SigningMethod(), token) - DefaultSigningKey.PreProcessToken(jwtToken) - return jwtToken.SignedString(DefaultSigningKey.SignKey()) + jwtToken := jwt.NewWithClaims(signingKey.SigningMethod(), token) + signingKey.PreProcessToken(jwtToken) + return jwtToken.SignedString(signingKey.SignKey()) } // OIDCToken represents an OpenID Connect id_token