From 45c15387b292c25b5d0572b2eb3f85414156372a Mon Sep 17 00:00:00 2001 From: wxiaoguang Date: Fri, 16 Feb 2024 23:18:30 +0800 Subject: [PATCH] Refactor JWT secret generating & decoding code (#29172) Old code is not consistent for generating & decoding the JWT secrets. Now, the callers only need to use 2 consistent functions: NewJwtSecretWithBase64 and DecodeJwtSecretBase64 And remove a non-common function Base64FixedDecode from util.go --- cmd/generate.go | 2 +- modules/generate/generate.go | 24 ++++++++------ modules/generate/generate_test.go | 34 ++++++++++++++++++++ modules/setting/lfs.go | 6 ++-- modules/setting/oauth2.go | 7 ++-- modules/util/util.go | 11 ------- modules/util/util_test.go | 14 -------- routers/install/install.go | 2 +- services/auth/source/oauth2/jwtsigningkey.go | 3 +- 9 files changed, 57 insertions(+), 46 deletions(-) create mode 100644 modules/generate/generate_test.go diff --git a/cmd/generate.go b/cmd/generate.go index 4ab10da22a..90b32ecaf0 100644 --- a/cmd/generate.go +++ b/cmd/generate.go @@ -70,7 +70,7 @@ func runGenerateInternalToken(c *cli.Context) error { } func runGenerateLfsJwtSecret(c *cli.Context) error { - _, jwtSecretBase64, err := generate.NewJwtSecretBase64() + _, jwtSecretBase64, err := generate.NewJwtSecretWithBase64() if err != nil { return err } diff --git a/modules/generate/generate.go b/modules/generate/generate.go index ee3c76059b..2d9a3dd902 100644 --- a/modules/generate/generate.go +++ b/modules/generate/generate.go @@ -7,6 +7,7 @@ package generate import ( "crypto/rand" "encoding/base64" + "fmt" "io" "time" @@ -38,19 +39,24 @@ func NewInternalToken() (string, error) { return internalToken, nil } -// NewJwtSecret generates a new value intended to be used for JWT secrets. -func NewJwtSecret() ([]byte, error) { - bytes := make([]byte, 32) - _, err := io.ReadFull(rand.Reader, bytes) - if err != nil { +const defaultJwtSecretLen = 32 + +// DecodeJwtSecretBase64 decodes a base64 encoded jwt secret into bytes, and check its length +func DecodeJwtSecretBase64(src string) ([]byte, error) { + encoding := base64.RawURLEncoding + decoded := make([]byte, encoding.DecodedLen(len(src))+3) + if n, err := encoding.Decode(decoded, []byte(src)); err != nil { return nil, err + } else if n != defaultJwtSecretLen { + return nil, fmt.Errorf("invalid base64 decoded length: %d, expects: %d", n, defaultJwtSecretLen) } - return bytes, nil + return decoded[:defaultJwtSecretLen], nil } -// NewJwtSecretBase64 generates a new base64 encoded value intended to be used for JWT secrets. -func NewJwtSecretBase64() ([]byte, string, error) { - bytes, err := NewJwtSecret() +// NewJwtSecretWithBase64 generates a jwt secret with its base64 encoded value intended to be used for saving into config file +func NewJwtSecretWithBase64() ([]byte, string, error) { + bytes := make([]byte, defaultJwtSecretLen) + _, err := io.ReadFull(rand.Reader, bytes) if err != nil { return nil, "", err } diff --git a/modules/generate/generate_test.go b/modules/generate/generate_test.go new file mode 100644 index 0000000000..af640a60c1 --- /dev/null +++ b/modules/generate/generate_test.go @@ -0,0 +1,34 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package generate + +import ( + "encoding/base64" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDecodeJwtSecretBase64(t *testing.T) { + _, err := DecodeJwtSecretBase64("abcd") + assert.ErrorContains(t, err, "invalid base64 decoded length") + _, err = DecodeJwtSecretBase64(strings.Repeat("a", 64)) + assert.ErrorContains(t, err, "invalid base64 decoded length") + + str32 := strings.Repeat("x", 32) + encoded32 := base64.RawURLEncoding.EncodeToString([]byte(str32)) + decoded32, err := DecodeJwtSecretBase64(encoded32) + assert.NoError(t, err) + assert.Equal(t, str32, string(decoded32)) +} + +func TestNewJwtSecretWithBase64(t *testing.T) { + secret, encoded, err := NewJwtSecretWithBase64() + assert.NoError(t, err) + assert.Len(t, secret, 32) + decoded, err := DecodeJwtSecretBase64(encoded) + assert.NoError(t, err) + assert.Equal(t, secret, decoded) +} diff --git a/modules/setting/lfs.go b/modules/setting/lfs.go index a5ea537cef..22a75f6008 100644 --- a/modules/setting/lfs.go +++ b/modules/setting/lfs.go @@ -4,12 +4,10 @@ package setting import ( - "encoding/base64" "fmt" "time" "code.gitea.io/gitea/modules/generate" - "code.gitea.io/gitea/modules/util" ) // LFS represents the configuration for Git LFS @@ -62,9 +60,9 @@ func loadLFSFrom(rootCfg ConfigProvider) error { } LFS.JWTSecretBase64 = loadSecret(rootCfg.Section("server"), "LFS_JWT_SECRET_URI", "LFS_JWT_SECRET") - LFS.JWTSecretBytes, err = util.Base64FixedDecode(base64.RawURLEncoding, []byte(LFS.JWTSecretBase64), 32) + LFS.JWTSecretBytes, err = generate.DecodeJwtSecretBase64(LFS.JWTSecretBase64) if err != nil { - LFS.JWTSecretBytes, LFS.JWTSecretBase64, err = generate.NewJwtSecretBase64() + LFS.JWTSecretBytes, LFS.JWTSecretBase64, err = generate.NewJwtSecretWithBase64() if err != nil { return fmt.Errorf("error generating JWT Secret for custom config: %v", err) } diff --git a/modules/setting/oauth2.go b/modules/setting/oauth2.go index 0d15e91ef0..e16e167024 100644 --- a/modules/setting/oauth2.go +++ b/modules/setting/oauth2.go @@ -4,13 +4,11 @@ package setting import ( - "encoding/base64" "math" "path/filepath" "code.gitea.io/gitea/modules/generate" "code.gitea.io/gitea/modules/log" - "code.gitea.io/gitea/modules/util" ) // OAuth2UsernameType is enum describing the way gitea 'name' should be generated from oauth2 data @@ -137,13 +135,12 @@ func loadOAuth2From(rootCfg ConfigProvider) { } if InstallLock { - if _, err := util.Base64FixedDecode(base64.RawURLEncoding, []byte(OAuth2.JWTSecretBase64), 32); err != nil { - key, err := generate.NewJwtSecret() + if _, err := generate.DecodeJwtSecretBase64(OAuth2.JWTSecretBase64); err != nil { + _, OAuth2.JWTSecretBase64, err = generate.NewJwtSecretWithBase64() if err != nil { log.Fatal("error generating JWT secret: %v", err) } - OAuth2.JWTSecretBase64 = base64.RawURLEncoding.EncodeToString(key) saveCfg, err := rootCfg.PrepareSaving() if err != nil { log.Fatal("save oauth2.JWT_SECRET failed: %v", err) diff --git a/modules/util/util.go b/modules/util/util.go index c47931f6c9..0e5c6a4e64 100644 --- a/modules/util/util.go +++ b/modules/util/util.go @@ -6,7 +6,6 @@ package util import ( "bytes" "crypto/rand" - "encoding/base64" "fmt" "math/big" "strconv" @@ -246,13 +245,3 @@ func ToFloat64(number any) (float64, error) { func ToPointer[T any](val T) *T { return &val } - -func Base64FixedDecode(encoding *base64.Encoding, src []byte, length int) ([]byte, error) { - decoded := make([]byte, encoding.DecodedLen(len(src))+3) - if n, err := encoding.Decode(decoded, src); err != nil { - return nil, err - } else if n != length { - return nil, fmt.Errorf("invalid base64 decoded length: %d, expects: %d", n, length) - } - return decoded[:length], nil -} diff --git a/modules/util/util_test.go b/modules/util/util_test.go index 8509d8aced..c5830ce01c 100644 --- a/modules/util/util_test.go +++ b/modules/util/util_test.go @@ -4,7 +4,6 @@ package util import ( - "encoding/base64" "regexp" "strings" "testing" @@ -234,16 +233,3 @@ func TestToPointer(t *testing.T) { val123 := 123 assert.False(t, &val123 == ToPointer(val123)) } - -func TestBase64FixedDecode(t *testing.T) { - _, err := Base64FixedDecode(base64.RawURLEncoding, []byte("abcd"), 32) - assert.ErrorContains(t, err, "invalid base64 decoded length") - _, err = Base64FixedDecode(base64.RawURLEncoding, []byte(strings.Repeat("a", 64)), 32) - assert.ErrorContains(t, err, "invalid base64 decoded length") - - str32 := strings.Repeat("x", 32) - encoded32 := base64.RawURLEncoding.EncodeToString([]byte(str32)) - decoded32, err := Base64FixedDecode(base64.RawURLEncoding, []byte(encoded32), 32) - assert.NoError(t, err) - assert.Equal(t, str32, string(decoded32)) -} diff --git a/routers/install/install.go b/routers/install/install.go index 5c0290d2cc..064575d34c 100644 --- a/routers/install/install.go +++ b/routers/install/install.go @@ -409,7 +409,7 @@ func SubmitInstall(ctx *context.Context) { cfg.Section("server").Key("LFS_START_SERVER").SetValue("true") cfg.Section("lfs").Key("PATH").SetValue(form.LFSRootPath) var lfsJwtSecret string - if _, lfsJwtSecret, err = generate.NewJwtSecretBase64(); err != nil { + if _, lfsJwtSecret, err = generate.NewJwtSecretWithBase64(); err != nil { ctx.RenderWithErr(ctx.Tr("install.lfs_jwt_secret_failed", err), tplInstall, &form) return } diff --git a/services/auth/source/oauth2/jwtsigningkey.go b/services/auth/source/oauth2/jwtsigningkey.go index eca0b8b7e1..2afe557b0d 100644 --- a/services/auth/source/oauth2/jwtsigningkey.go +++ b/services/auth/source/oauth2/jwtsigningkey.go @@ -18,6 +18,7 @@ import ( "path/filepath" "strings" + "code.gitea.io/gitea/modules/generate" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/util" @@ -336,7 +337,7 @@ func InitSigningKey() error { // loadSymmetricKey checks if the configured secret is valid. // If it is not valid, it will return an error. func loadSymmetricKey() (any, error) { - return util.Base64FixedDecode(base64.RawURLEncoding, []byte(setting.OAuth2.JWTSecretBase64), 32) + return generate.DecodeJwtSecretBase64(setting.OAuth2.JWTSecretBase64) } // loadOrCreateAsymmetricKey checks if the configured private key exists.