Skip to content

Commit

Permalink
feat: add ParseOptions field for HertzJWTMiddleware (#13)
Browse files Browse the repository at this point in the history
* feat: add ParseOptions field for HertzJWTMiddleware

* improve ut
  • Loading branch information
justlorain authored Dec 25, 2022
1 parent 293c661 commit 08f5053
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 15 deletions.
42 changes: 27 additions & 15 deletions auth_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ package jwt
import (
"context"
"crypto/rsa"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -181,6 +182,9 @@ type HertzJWTMiddleware struct {

// CookieSameSite allow use protocol.CookieSameSite cookie param
CookieSameSite protocol.CookieSameSite

// ParseOptions allow to modify jwt's parser methods
ParseOptions []jwt.ParserOption
}

var (
Expand Down Expand Up @@ -447,19 +451,27 @@ func (mw *HertzJWTMiddleware) middlewareImpl(ctx context.Context, c *app.Request
return
}

if claims["exp"] == nil {
switch v := claims["exp"].(type) {
case nil:
mw.unauthorized(ctx, c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrMissingExpField, ctx, c))
return
}

if _, ok := claims["exp"].(float64); !ok {
mw.unauthorized(ctx, c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrWrongFormatOfExp, ctx, c))
return
}

if int64(claims["exp"].(float64)) < mw.TimeFunc().Unix() {
mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrExpiredToken, ctx, c))
return
case float64:
if int64(v) < mw.TimeFunc().Unix() {
mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrExpiredToken, ctx, c))
return
}
case json.Number:
n, err := v.Int64()
if err != nil {
mw.unauthorized(ctx, c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrWrongFormatOfExp, ctx, c))
return
}
if n < mw.TimeFunc().Unix() {
mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrExpiredToken, ctx, c))
return
}
default:
mw.Unauthorized(ctx, c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrWrongFormatOfExp, ctx, c))
}

c.Set("JWT_PAYLOAD", claims)
Expand Down Expand Up @@ -728,7 +740,7 @@ func (mw *HertzJWTMiddleware) ParseToken(ctx context.Context, c *app.RequestCont
}

if mw.KeyFunc != nil {
return jwt.Parse(token, mw.KeyFunc)
return jwt.Parse(token, mw.KeyFunc, mw.ParseOptions...)
}

return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
Expand All @@ -743,13 +755,13 @@ func (mw *HertzJWTMiddleware) ParseToken(ctx context.Context, c *app.RequestCont
c.Set("JWT_TOKEN", token)

return mw.Key, nil
})
}, mw.ParseOptions...)
}

// ParseTokenString parse jwt token string
func (mw *HertzJWTMiddleware) ParseTokenString(token string) (*jwt.Token, error) {
if mw.KeyFunc != nil {
return jwt.Parse(token, mw.KeyFunc)
return jwt.Parse(token, mw.KeyFunc, mw.ParseOptions...)
}

return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
Expand All @@ -761,7 +773,7 @@ func (mw *HertzJWTMiddleware) ParseTokenString(token string) (*jwt.Token, error)
}

return mw.Key, nil
})
}, mw.ParseOptions...)
}

func (mw *HertzJWTMiddleware) unauthorized(ctx context.Context, c *app.RequestContext, code int, message string) {
Expand Down
49 changes: 49 additions & 0 deletions auth_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ package jwt
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -90,6 +91,28 @@ func makeTokenString(SigningAlgorithm, username string) string {
return tokenString
}

func makeTokenStringWithUserID(SigningAlgorithm string, userID int64) string {
if SigningAlgorithm == "" {
SigningAlgorithm = "HS256"
}

token := jwt.New(jwt.GetSigningMethod(SigningAlgorithm))
claims := token.Claims.(jwt.MapClaims)
claims["identity"] = userID
claims["exp"] = time.Now().Add(time.Hour).Unix()
claims["orig_iat"] = time.Now().Unix()
var tokenString string
if SigningAlgorithm == "RS256" {
keyData, _ := ioutil.ReadFile("testdata/jwtRS256.key")
signKey, _ := jwt.ParseRSAPrivateKeyFromPEM(keyData)
tokenString, _ = token.SignedString(signKey)
} else {
tokenString, _ = token.SignedString(key)
}

return tokenString
}

func keyFunc(token *jwt.Token) (interface{}, error) {
cert, err := ioutil.ReadFile("testdata/jwtRS256.key.pub")
if err != nil {
Expand Down Expand Up @@ -533,6 +556,32 @@ func TestAuthorizator(t *testing.T) {
assert.DeepEqual(t, http.StatusOK, w.Code)
}

func TestParseTokenWithJsonNumber(t *testing.T) {
var userID int64 = 64
authMiddleware, _ := New(&HertzJWTMiddleware{
Realm: "test zone",
Key: key,
Timeout: time.Hour,
MaxRefresh: time.Hour * 24,
IdentityHandler: func(ctx context.Context, c *app.RequestContext) interface{} {
claims := ExtractClaims(ctx, c)
testNum, err := claims["identity"].(json.Number).Int64()
assert.Nil(t, err)
assert.DeepEqual(t, userID, testNum)
return testNum
},
Unauthorized: func(ctx context.Context, c *app.RequestContext, code int, message string) {
c.String(code, message)
},
ParseOptions: []jwt.ParserOption{jwt.WithJSONNumber()},
})

handler := hertzHandler(authMiddleware)

w := ut.PerformRequest(handler, http.MethodGet, "/auth/hello", nil, ut.Header{Key: "Authorization", Value: "Bearer " + makeTokenStringWithUserID("HS256", userID)})
assert.DeepEqual(t, http.StatusOK, w.Code)
}

func TestClaimsDuringAuthorization(t *testing.T) {
// the middleware to test
authMiddleware, _ := New(&HertzJWTMiddleware{
Expand Down

0 comments on commit 08f5053

Please sign in to comment.