diff --git a/auth_jwt.go b/auth_jwt.go index 96d2cda..494e5bd 100644 --- a/auth_jwt.go +++ b/auth_jwt.go @@ -28,6 +28,7 @@ package jwt import ( "context" "crypto/rsa" + "encoding/json" "errors" "io/ioutil" "net/http" @@ -181,6 +182,9 @@ type HertzJWTMiddleware struct { // CookieSameSite allow use protocol.CookieSameSite cookie param CookieSameSite protocol.CookieSameSite + + // ParseOptions allow to modify jwt's parser methods + ParseOptions []jwt.ParserOption } var ( @@ -447,19 +451,27 @@ func (mw *HertzJWTMiddleware) middlewareImpl(ctx context.Context, c *app.Request return } - if claims["exp"] == nil { + switch v := claims["exp"].(type) { + case nil: mw.unauthorized(ctx, c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrMissingExpField, ctx, c)) return - } - - if _, ok := claims["exp"].(float64); !ok { - mw.unauthorized(ctx, c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrWrongFormatOfExp, ctx, c)) - return - } - - if int64(claims["exp"].(float64)) < mw.TimeFunc().Unix() { - mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrExpiredToken, ctx, c)) - return + case float64: + if int64(v) < mw.TimeFunc().Unix() { + mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrExpiredToken, ctx, c)) + return + } + case json.Number: + n, err := v.Int64() + if err != nil { + mw.unauthorized(ctx, c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrWrongFormatOfExp, ctx, c)) + return + } + if n < mw.TimeFunc().Unix() { + mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrExpiredToken, ctx, c)) + return + } + default: + mw.Unauthorized(ctx, c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrWrongFormatOfExp, ctx, c)) } c.Set("JWT_PAYLOAD", claims) @@ -728,7 +740,7 @@ func (mw *HertzJWTMiddleware) ParseToken(ctx context.Context, c *app.RequestCont } if mw.KeyFunc != nil { - return jwt.Parse(token, mw.KeyFunc) + return jwt.Parse(token, mw.KeyFunc, mw.ParseOptions...) } return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { @@ -743,13 +755,13 @@ func (mw *HertzJWTMiddleware) ParseToken(ctx context.Context, c *app.RequestCont c.Set("JWT_TOKEN", token) return mw.Key, nil - }) + }, mw.ParseOptions...) } // ParseTokenString parse jwt token string func (mw *HertzJWTMiddleware) ParseTokenString(token string) (*jwt.Token, error) { if mw.KeyFunc != nil { - return jwt.Parse(token, mw.KeyFunc) + return jwt.Parse(token, mw.KeyFunc, mw.ParseOptions...) } return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { @@ -761,7 +773,7 @@ func (mw *HertzJWTMiddleware) ParseTokenString(token string) (*jwt.Token, error) } return mw.Key, nil - }) + }, mw.ParseOptions...) } func (mw *HertzJWTMiddleware) unauthorized(ctx context.Context, c *app.RequestContext, code int, message string) { diff --git a/auth_jwt_test.go b/auth_jwt_test.go index af9c340..51c1476 100644 --- a/auth_jwt_test.go +++ b/auth_jwt_test.go @@ -28,6 +28,7 @@ package jwt import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io/ioutil" @@ -90,6 +91,28 @@ func makeTokenString(SigningAlgorithm, username string) string { return tokenString } +func makeTokenStringWithUserID(SigningAlgorithm string, userID int64) string { + if SigningAlgorithm == "" { + SigningAlgorithm = "HS256" + } + + token := jwt.New(jwt.GetSigningMethod(SigningAlgorithm)) + claims := token.Claims.(jwt.MapClaims) + claims["identity"] = userID + claims["exp"] = time.Now().Add(time.Hour).Unix() + claims["orig_iat"] = time.Now().Unix() + var tokenString string + if SigningAlgorithm == "RS256" { + keyData, _ := ioutil.ReadFile("testdata/jwtRS256.key") + signKey, _ := jwt.ParseRSAPrivateKeyFromPEM(keyData) + tokenString, _ = token.SignedString(signKey) + } else { + tokenString, _ = token.SignedString(key) + } + + return tokenString +} + func keyFunc(token *jwt.Token) (interface{}, error) { cert, err := ioutil.ReadFile("testdata/jwtRS256.key.pub") if err != nil { @@ -533,6 +556,32 @@ func TestAuthorizator(t *testing.T) { assert.DeepEqual(t, http.StatusOK, w.Code) } +func TestParseTokenWithJsonNumber(t *testing.T) { + var userID int64 = 64 + authMiddleware, _ := New(&HertzJWTMiddleware{ + Realm: "test zone", + Key: key, + Timeout: time.Hour, + MaxRefresh: time.Hour * 24, + IdentityHandler: func(ctx context.Context, c *app.RequestContext) interface{} { + claims := ExtractClaims(ctx, c) + testNum, err := claims["identity"].(json.Number).Int64() + assert.Nil(t, err) + assert.DeepEqual(t, userID, testNum) + return testNum + }, + Unauthorized: func(ctx context.Context, c *app.RequestContext, code int, message string) { + c.String(code, message) + }, + ParseOptions: []jwt.ParserOption{jwt.WithJSONNumber()}, + }) + + handler := hertzHandler(authMiddleware) + + w := ut.PerformRequest(handler, http.MethodGet, "/auth/hello", nil, ut.Header{Key: "Authorization", Value: "Bearer " + makeTokenStringWithUserID("HS256", userID)}) + assert.DeepEqual(t, http.StatusOK, w.Code) +} + func TestClaimsDuringAuthorization(t *testing.T) { // the middleware to test authMiddleware, _ := New(&HertzJWTMiddleware{