diff --git a/auth.go b/auth.go index 1c587ea..f90ba86 100644 --- a/auth.go +++ b/auth.go @@ -61,6 +61,7 @@ func NewJWTParser(cfg *Config, opts ...JWTParserOption) (JWTParser, error) { ExpectedAudience: cfg.JWT.ExpectedAudience, TrustedIssuerNotFoundFallback: options.trustedIssuerNotFoundFallback, LoggerProvider: options.loggerProvider, + ClaimsTemplate: options.claimsTemplate, } if cfg.JWT.ClaimsCache.Enabled { @@ -88,6 +89,7 @@ type jwtParserOptions struct { loggerProvider func(ctx context.Context) log.FieldLogger prometheusLibInstanceLabel string trustedIssuerNotFoundFallback jwt.TrustedIssNotFoundFallback + claimsTemplate jwt.Claims } // JWTParserOption is an option for creating JWTParser. @@ -114,6 +116,13 @@ func WithJWTParserTrustedIssuerNotFoundFallback(fallback jwt.TrustedIssNotFoundF } } +// WithJWTParserClaimsTemplate sets the claims template for JWTParser. +func WithJWTParserClaimsTemplate(claimsTemplate jwt.Claims) JWTParserOption { + return func(options *jwtParserOptions) { + options.claimsTemplate = claimsTemplate + } +} + // NewTokenIntrospector creates a new TokenIntrospector with the given configuration, token provider and scope filter. // If cfg.Introspection.ClaimsCache.Enabled or cfg.Introspection.NegativeCache.Enabled is true, // then idptoken.CachingIntrospector created, otherwise - idptoken.Introspector. @@ -122,7 +131,7 @@ func WithJWTParserTrustedIssuerNotFoundFallback(fallback jwt.TrustedIssNotFoundF func NewTokenIntrospector( cfg *Config, tokenProvider idptoken.IntrospectionTokenProvider, - scopeFilter []idptoken.IntrospectionScopeFilterAccessPolicy, + scopeFilter jwt.ScopeFilter, opts ...TokenIntrospectorOption, ) (*idptoken.Introspector, error) { options := tokenIntrospectorOptions{loggerProvider: middleware.GetLoggerFromContext} @@ -159,6 +168,7 @@ func NewTokenIntrospector( HTTPClient: idputil.MakeDefaultHTTPClient(cfg.HTTPClient.RequestTimeout, options.loggerProvider), AccessTokenScope: cfg.Introspection.AccessTokenScope, LoggerProvider: options.loggerProvider, + ResultTemplate: options.resultTemplate, ScopeFilter: scopeFilter, TrustedIssuerNotFoundFallback: options.trustedIssuerNotFoundFallback, PrometheusLibInstanceLabel: options.prometheusLibInstanceLabel, @@ -189,6 +199,7 @@ type tokenIntrospectorOptions struct { loggerProvider func(ctx context.Context) log.FieldLogger prometheusLibInstanceLabel string trustedIssuerNotFoundFallback idptoken.TrustedIssNotFoundFallback + resultTemplate idptoken.IntrospectionResult } // TokenIntrospectorOption is an option for creating TokenIntrospector. @@ -218,6 +229,13 @@ func WithTokenIntrospectorTrustedIssuerNotFoundFallback( } } +// WithTokenIntrospectorResultTemplate sets the result template for TokenIntrospector. +func WithTokenIntrospectorResultTemplate(resultTemplate idptoken.IntrospectionResult) TokenIntrospectorOption { + return func(options *tokenIntrospectorOptions) { + options.resultTemplate = resultTemplate + } +} + // Role is a representation of role which may be used for verifying access. type Role struct { Namespace string @@ -225,11 +243,12 @@ type Role struct { } // NewVerifyAccessByRolesInJWT creates a new function which may be used for verifying access by roles in JWT scope. -func NewVerifyAccessByRolesInJWT(roles ...Role) func(r *http.Request, claims *jwt.Claims) bool { - return func(_ *http.Request, claims *jwt.Claims) bool { +func NewVerifyAccessByRolesInJWT(roles ...Role) func(r *http.Request, claims jwt.Claims) bool { + return func(_ *http.Request, claims jwt.Claims) bool { + claimsScope := claims.GetScope() for i := range roles { - for j := range claims.Scope { - if roles[i].Name == claims.Scope[j].Role && roles[i].Namespace == claims.Scope[j].ResourceNamespace { + for j := range claimsScope { + if roles[i].Name == claimsScope[j].Role && roles[i].Namespace == claimsScope[j].ResourceNamespace { return true } } @@ -239,8 +258,8 @@ func NewVerifyAccessByRolesInJWT(roles ...Role) func(r *http.Request, claims *jw } // NewVerifyAccessByRolesInJWTMaker creates a new function which may be used for verifying access by roles in JWT scope given a namespace. -func NewVerifyAccessByRolesInJWTMaker(namespace string) func(roleNames ...string) func(r *http.Request, claims *jwt.Claims) bool { - return func(roleNames ...string) func(r *http.Request, claims *jwt.Claims) bool { +func NewVerifyAccessByRolesInJWTMaker(namespace string) func(roleNames ...string) func(r *http.Request, claims jwt.Claims) bool { + return func(roleNames ...string) func(r *http.Request, claims jwt.Claims) bool { roles := make([]Role, 0, len(roleNames)) for i := range roleNames { roles = append(roles, Role{Namespace: namespace, Name: roleNames[i]}) diff --git a/auth_test.go b/auth_test.go index dda49b7..e00fb89 100644 --- a/auth_test.go +++ b/auth_test.go @@ -44,7 +44,7 @@ func TestNewJWTParser(t *gotesting.T) { require.NoError(t, idpSrv.StartAndWaitForReady(time.Second)) defer func() { _ = idpSrv.Shutdown(context.Background()) }() - claims := &jwt.Claims{ + claims := &jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: idpSrv.URL(), ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)), @@ -53,7 +53,7 @@ func TestNewJWTParser(t *gotesting.T) { } token := idptest.MustMakeTokenStringSignedWithTestKey(claims) - claimsWithNamedIssuer := &jwt.Claims{ + claimsWithNamedIssuer := &jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)), @@ -66,7 +66,7 @@ func TestNewJWTParser(t *gotesting.T) { name string token string cfg *Config - expectedClaims *jwt.Claims + expectedClaims jwt.Claims checkFn func(t *gotesting.T, jwtParser JWTParser) }{ { @@ -149,7 +149,7 @@ func TestNewTokenIntrospector(t *gotesting.T) { require.NoError(t, grpcIDPSrv.StartAndWaitForReady(time.Second)) defer func() { grpcIDPSrv.GracefulStop() }() - claims := &jwt.Claims{ + claims := &jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: httpIDPSrv.URL(), ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)), @@ -158,7 +158,7 @@ func TestNewTokenIntrospector(t *gotesting.T) { } token := idptest.MustMakeTokenStringSignedWithTestKey(claims) - claimsWithNamedIssuer := &jwt.Claims{ + claimsWithNamedIssuer := &jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)), @@ -174,8 +174,8 @@ func TestNewTokenIntrospector(t *gotesting.T) { Role: "admin", ResourcePath: "resource-" + uuid.NewString(), }} - httpServerIntrospector.SetResultForToken(opaqueToken, idptoken.IntrospectionResult{ - Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueTokenScope}}) + httpServerIntrospector.SetResultForToken(opaqueToken, &idptoken.DefaultIntrospectionResult{ + Active: true, TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{Scope: opaqueTokenScope}}) grpcServerIntrospector.SetResultForToken(opaqueToken, &pb.IntrospectTokenResponse{ Active: true, TokenType: idputil.TokenTypeBearer, Scope: []*pb.AccessTokenScope{ { @@ -197,10 +197,10 @@ func TestNewTokenIntrospector(t *gotesting.T) { name: "new token introspector, dynamic endpoint, trusted issuers map", cfg: &Config{JWT: JWTConfig{TrustedIssuers: map[string]string{testIss: httpIDPSrv.URL()}}, Introspection: IntrospectionConfig{Enabled: true}}, token: tokenWithNamedIssuer, - expectedResult: idptoken.IntrospectionResult{ - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: *claimsWithNamedIssuer, + expectedResult: &idptoken.DefaultIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: *claimsWithNamedIssuer, }, checkCacheFn: func(t *gotesting.T, introspector *idptoken.Introspector) { require.Empty(t, introspector.ClaimsCache.Len(context.Background())) @@ -211,10 +211,10 @@ func TestNewTokenIntrospector(t *gotesting.T) { name: "new token introspector, dynamic endpoint, trusted issuer urls", cfg: &Config{JWT: JWTConfig{TrustedIssuerURLs: []string{httpIDPSrv.URL()}}, Introspection: IntrospectionConfig{Enabled: true}}, token: token, - expectedResult: idptoken.IntrospectionResult{ - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: *claims, + expectedResult: &idptoken.DefaultIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: *claims, }, checkCacheFn: func(t *gotesting.T, introspector *idptoken.Introspector) { require.Empty(t, introspector.ClaimsCache.Len(context.Background())) @@ -228,10 +228,10 @@ func TestNewTokenIntrospector(t *gotesting.T) { Introspection: IntrospectionConfig{Enabled: true, ClaimsCache: IntrospectionCacheConfig{Enabled: true}}, }, token: tokenWithNamedIssuer, - expectedResult: idptoken.IntrospectionResult{ - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: *claimsWithNamedIssuer, + expectedResult: &idptoken.DefaultIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: *claimsWithNamedIssuer, }, checkCacheFn: func(t *gotesting.T, introspector *idptoken.Introspector) { require.Equal(t, 1, introspector.ClaimsCache.Len(context.Background())) @@ -245,10 +245,10 @@ func TestNewTokenIntrospector(t *gotesting.T) { Introspection: IntrospectionConfig{Enabled: true, ClaimsCache: IntrospectionCacheConfig{Enabled: true}}, }, token: token, - expectedResult: idptoken.IntrospectionResult{ - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: *claims, + expectedResult: &idptoken.DefaultIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: *claims, }, checkCacheFn: func(t *gotesting.T, introspector *idptoken.Introspector) { require.Equal(t, 1, introspector.ClaimsCache.Len(context.Background())) @@ -265,10 +265,10 @@ func TestNewTokenIntrospector(t *gotesting.T) { }, }, token: opaqueToken, - expectedResult: idptoken.IntrospectionResult{ - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{Scope: opaqueTokenScope}, + expectedResult: &idptoken.DefaultIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: jwt.DefaultClaims{Scope: opaqueTokenScope}, }, checkCacheFn: func(t *gotesting.T, introspector *idptoken.Introspector) { require.Equal(t, 1, introspector.ClaimsCache.Len(context.Background())) @@ -291,10 +291,10 @@ func TestNewTokenIntrospector(t *gotesting.T) { }, }, token: opaqueToken, - expectedResult: idptoken.IntrospectionResult{ - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{Scope: opaqueTokenScope}, + expectedResult: &idptoken.DefaultIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: jwt.DefaultClaims{Scope: opaqueTokenScope}, }, checkCacheFn: func(t *gotesting.T, introspector *idptoken.Introspector) { require.Empty(t, introspector.ClaimsCache.Len(context.Background())) @@ -323,7 +323,7 @@ func TestNewTokenIntrospector(t *gotesting.T) { } func TestNewVerifyAccessByJWTRoles(t *gotesting.T) { - jwtClaims := &jwt.Claims{Scope: []jwt.AccessPolicy{ + jwtClaims := &jwt.DefaultClaims{Scope: []jwt.AccessPolicy{ {ResourceNamespace: "policy_manager", Role: "admin"}, {ResourceNamespace: "scan_service", Role: "admin"}, {Role: "backup_user"}, @@ -347,7 +347,7 @@ func TestNewVerifyAccessByJWTRoles(t *gotesting.T) { } func TestNewVerifyAccessByJWTRolesMaker(t *gotesting.T) { - jwtClaims := &jwt.Claims{Scope: []jwt.AccessPolicy{ + jwtClaims := &jwt.DefaultClaims{Scope: []jwt.AccessPolicy{ {ResourceNamespace: "policy_manager", Role: "admin"}, {ResourceNamespace: "scan_service", Role: "admin"}, {Role: "backup_user"}, diff --git a/examples/authn-middleware/main.go b/examples/authn-middleware/main.go index 6fac6c5..246fcb6 100644 --- a/examples/authn-middleware/main.go +++ b/examples/authn-middleware/main.go @@ -48,8 +48,9 @@ func runApp() error { srvMux := http.NewServeMux() srvMux.Handle("/", authNMw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - jwtClaims := authkit.GetJWTClaimsFromContext(r.Context()) // get JWT claims from the request context - _, _ = rw.Write([]byte(fmt.Sprintf("Hello, %s", jwtClaims.Subject))) + jwtClaims := authkit.GetJWTClaimsFromContext(r.Context()) // get JWT claims from the request context + tokenSubject, _ := jwtClaims.GetSubject() // error is always nil here unless custom claims are used + _, _ = rw.Write([]byte(fmt.Sprintf("Hello, %s", tokenSubject))) // use the subject to greet the user }))) if err = http.ListenAndServe(":8080", middleware.Logging(logger)(srvMux)); err != nil && !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("listen and HTTP server: %w", err) diff --git a/examples/idp-test-server/main.go b/examples/idp-test-server/main.go index 80d15ee..56043b2 100644 --- a/examples/idp-test-server/main.go +++ b/examples/idp-test-server/main.go @@ -74,16 +74,17 @@ type demoTokenIntrospector struct { func (dti *demoTokenIntrospector) IntrospectToken(r *http.Request, token string) (idptoken.IntrospectionResult, error) { if bearerToken := authkit.GetBearerTokenFromRequest(r); bearerToken != "access-token-with-introspection-permission" { - return idptoken.IntrospectionResult{}, idptest.ErrUnauthorized + return nil, idptest.ErrUnauthorized } claims, err := dti.jwtParser.Parse(r.Context(), token) if err != nil { - return idptoken.IntrospectionResult{Active: false}, nil + return &idptoken.DefaultIntrospectionResult{Active: false}, nil } - if claims.Subject == "admin2" { - claims.Scope = append(claims.Scope, jwt.AccessPolicy{ResourceNamespace: "my_service", Role: "admin"}) + defClaims := claims.(*jwt.DefaultClaims) // type assertion is safe here since we don't use custom claims + if defClaims.Subject == "admin2" { + defClaims.Scope = append(defClaims.Scope, jwt.AccessPolicy{ResourceNamespace: "my_service", Role: "admin"}) } - return idptoken.IntrospectionResult{Active: true, TokenType: "Bearer", Claims: *claims}, nil + return &idptoken.DefaultIntrospectionResult{Active: true, TokenType: "Bearer", DefaultClaims: *defClaims}, nil } type demoClaimsProvider struct { @@ -92,9 +93,9 @@ type demoClaimsProvider struct { func (dcp *demoClaimsProvider) Provide(r *http.Request) (jwt.Claims, error) { username, password, ok := r.BasicAuth() if !ok { - return jwt.Claims{}, idptest.ErrUnauthorized + return nil, idptest.ErrUnauthorized } - var claims jwt.Claims + claims := &jwt.DefaultClaims{} switch { case username == "user" && password == "user-pwd": claims.Subject = "user" @@ -104,7 +105,7 @@ func (dcp *demoClaimsProvider) Provide(r *http.Request) (jwt.Claims, error) { case username == "admin2" && password == "admin2-pwd": claims.Subject = "admin2" default: - return jwt.Claims{}, idptest.ErrUnauthorized + return nil, idptest.ErrUnauthorized } return claims, nil } diff --git a/examples/token-introspection/grpc-server/main.go b/examples/token-introspection/grpc-server/main.go index bad5e40..0b7efc7 100644 --- a/examples/token-introspection/grpc-server/main.go +++ b/examples/token-introspection/grpc-server/main.go @@ -88,23 +88,25 @@ func (dti *demoGRPCTokenIntrospector) IntrospectToken( if authMeta != "Bearer "+accessTokenWithIntrospectionPermission { return nil, idptest.ErrUnauthorized } + claims, err := dti.jwtParser.Parse(ctx, req.Token) if err != nil { return &pb.IntrospectTokenResponse{Active: false}, nil } - if claims.Subject == "admin2" { - claims.Scope = append(claims.Scope, jwt.AccessPolicy{ResourceNamespace: "my_service", Role: "admin"}) + defClaims := claims.(*jwt.DefaultClaims) // type assertion is safe here since we don't use custom claims + if defClaims.Subject == "admin2" { + defClaims.Scope = append(claims.GetScope(), jwt.AccessPolicy{ResourceNamespace: "my_service", Role: "admin"}) } resp := &pb.IntrospectTokenResponse{ Active: true, TokenType: "Bearer", - Sub: claims.Subject, - Exp: claims.ExpiresAt.Unix(), - Aud: claims.Audience, - Iss: claims.Issuer, - Scope: make([]*pb.AccessTokenScope, 0, len(claims.Scope)), + Sub: defClaims.Subject, + Exp: defClaims.ExpiresAt.Unix(), + Aud: defClaims.Audience, + Iss: defClaims.Issuer, + Scope: make([]*pb.AccessTokenScope, 0, len(defClaims.Scope)), } - for _, policy := range claims.Scope { + for _, policy := range defClaims.Scope { resp.Scope = append(resp.Scope, &pb.AccessTokenScope{ ResourceNamespace: policy.ResourceNamespace, RoleName: policy.Role, diff --git a/examples/token-introspection/main.go b/examples/token-introspection/main.go index b4eb0c7..346db47 100644 --- a/examples/token-introspection/main.go +++ b/examples/token-introspection/main.go @@ -18,7 +18,7 @@ import ( "github.com/acronis/go-appkit/log" "github.com/acronis/go-authkit" - "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/jwt" ) const ( @@ -49,8 +49,8 @@ func runApp() error { } // Create token introspector. - introspectionScopeFilter := []idptoken.IntrospectionScopeFilterAccessPolicy{{ResourceNamespace: serviceAccessPolicy}} - tokenIntrospector, err := authkit.NewTokenIntrospector(cfg.Auth, introspectionTokenProvider{}, introspectionScopeFilter) + tokenIntrospector, err := authkit.NewTokenIntrospector(cfg.Auth, + introspectionTokenProvider{}, jwt.ScopeFilter{{ResourceNamespace: serviceAccessPolicy}}) if err != nil { return fmt.Errorf("create token introspector: %w", err) } @@ -79,12 +79,14 @@ func runApp() error { // "/" endpoint will be available for all authenticated users. srvMux.Handle("/", authNMw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { jwtClaims := authkit.GetJWTClaimsFromContext(r.Context()) // get JWT claims from the request context - _, _ = rw.Write([]byte(fmt.Sprintf("Hello, %s", jwtClaims.Subject))) + tokenSubject, _ := jwtClaims.GetSubject() // error is always nil here unless custom claims are used + _, _ = rw.Write([]byte(fmt.Sprintf("Hello, %s", tokenSubject))) }))) // "/admin" endpoint will be available only for users with the "admin" role. srvMux.Handle("/admin", authZMw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { jwtClaims := authkit.GetJWTClaimsFromContext(r.Context()) // Get JWT claims from the request context. - _, _ = rw.Write([]byte(fmt.Sprintf("Hi, %s", jwtClaims.Subject))) + tokenSubject, _ := jwtClaims.GetSubject() // error is always nil here unless custom claims are used + _, _ = rw.Write([]byte(fmt.Sprintf("Hi, %s", tokenSubject))) }))) if err = http.ListenAndServe(":8080", middleware.Logging(logger)(srvMux)); err != nil && !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("listen and HTTP server: %w", err) diff --git a/idptest/http_server_test.go b/idptest/http_server_test.go index 5151e88..db9bc7b 100644 --- a/idptest/http_server_test.go +++ b/idptest/http_server_test.go @@ -110,7 +110,7 @@ func TestHTTPServerDefault(t *gotesting.T) { respBody, err = io.ReadAll(resp.Body) require.NoError(t, err) require.NoError(t, resp.Body.Close()) - var introspectionRespData idptoken.IntrospectionResult + var introspectionRespData idptoken.DefaultIntrospectionResult require.NoError(t, json.Unmarshal(respBody, &introspectionRespData)) require.True(t, introspectionRespData.Active) require.Equal(t, idpSrv.URL(), introspectionRespData.Issuer) diff --git a/idptest/jwt.go b/idptest/jwt.go index 886f318..c8098f1 100644 --- a/idptest/jwt.go +++ b/idptest/jwt.go @@ -14,6 +14,7 @@ import ( "github.com/mendsley/gojwk" "github.com/acronis/go-authkit/internal/idputil" + "github.com/acronis/go-authkit/jwt" ) // SignToken signs token with key. @@ -33,7 +34,7 @@ func MustSignToken(token *jwtgo.Token, rsaPrivateKey interface{}) string { // MakeTokenStringWithHeader create test signed token with claims and headers. func MakeTokenStringWithHeader( - claims jwtgo.Claims, kid string, rsaPrivateKey interface{}, header map[string]interface{}, + claims jwt.Claims, kid string, rsaPrivateKey interface{}, header map[string]interface{}, ) (string, error) { token := jwtgo.NewWithClaims(jwtgo.SigningMethodRS256, claims) token.Header["typ"] = idputil.JWTTypeAccessToken @@ -47,7 +48,7 @@ func MakeTokenStringWithHeader( // MustMakeTokenStringWithHeader create test signed token with claims and headers. // It panics if error occurs. func MustMakeTokenStringWithHeader( - claims jwtgo.Claims, kid string, rsaPrivateKey interface{}, header map[string]interface{}, + claims jwt.Claims, kid string, rsaPrivateKey interface{}, header map[string]interface{}, ) string { token, err := MakeTokenStringWithHeader(claims, kid, rsaPrivateKey, header) if err != nil { @@ -57,13 +58,13 @@ func MustMakeTokenStringWithHeader( } // MakeTokenString create signed token with claims. -func MakeTokenString(claims jwtgo.Claims, kid string, rsaPrivateKey interface{}) (string, error) { +func MakeTokenString(claims jwt.Claims, kid string, rsaPrivateKey interface{}) (string, error) { return MakeTokenStringWithHeader(claims, kid, rsaPrivateKey, nil) } // MustMakeTokenString create signed token with claims. // It panics if error occurs. -func MustMakeTokenString(claims jwtgo.Claims, kid string, rsaPrivateKey interface{}) string { +func MustMakeTokenString(claims jwt.Claims, kid string, rsaPrivateKey interface{}) string { token, err := MakeTokenStringWithHeader(claims, kid, rsaPrivateKey, nil) if err != nil { panic(err) @@ -80,14 +81,14 @@ func GetTestRSAPrivateKey() crypto.PrivateKey { } // MakeTokenStringSignedWithTestKey create test token signed with the pre-defined private key (TestKeyID) for testing. -func MakeTokenStringSignedWithTestKey(claims jwtgo.Claims) (string, error) { +func MakeTokenStringSignedWithTestKey(claims jwt.Claims) (string, error) { return MakeTokenStringWithHeader(claims, TestKeyID, GetTestRSAPrivateKey(), nil) } // MustMakeTokenStringSignedWithTestKey create test token signed // with the pre-defined private key (TestKeyID) for testing. // It panics if error occurs. -func MustMakeTokenStringSignedWithTestKey(claims jwtgo.Claims) string { +func MustMakeTokenStringSignedWithTestKey(claims jwt.Claims) string { token, err := MakeTokenStringSignedWithTestKey(claims) if err != nil { panic(err) diff --git a/idptest/jwt_test.go b/idptest/jwt_test.go index d86bfe5..d3f9683 100644 --- a/idptest/jwt_test.go +++ b/idptest/jwt_test.go @@ -28,7 +28,7 @@ func TestMakeTokenStringWithHeader(t *testing.T) { issuerConfigServer := httptest.NewServer(&OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) defer issuerConfigServer.Close() - jwtClaims := &jwt.Claims{ + jwtClaims := &jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), @@ -44,7 +44,7 @@ func TestMakeTokenStringWithHeader(t *testing.T) { require.NoError(t, err) require.Equal( t, - []jwt.AccessPolicy{{ResourceNamespace: "policy_manager", Role: "admin"}}, - parsedClaims.Scope, + jwt.Scope{{ResourceNamespace: "policy_manager", Role: "admin"}}, + parsedClaims.GetScope(), ) } diff --git a/idptest/token_handlers.go b/idptest/token_handlers.go index 03d845d..fdbd6ec 100644 --- a/idptest/token_handlers.go +++ b/idptest/token_handlers.go @@ -37,7 +37,7 @@ func (h *TokenHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { h.servedCount.Add(1) - var claims jwt.Claims + var claims jwt.Claims = &jwt.DefaultClaims{} if h.ClaimsProvider != nil { var err error if claims, err = h.ClaimsProvider.Provide(r); err != nil { @@ -49,14 +49,16 @@ func (h *TokenHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { return } } - if claims.ID == "" { - claims.ID = uuid.NewString() - } - if claims.ExpiresAt == nil { - claims.ExpiresAt = jwtgo.NewNumericDate(time.Now().Add(time.Hour)) // By default, token expires in 1 hour. - } - if claims.Issuer == "" { - claims.Issuer = h.Issuer + if defaultClaims, ok := claims.(*jwt.DefaultClaims); ok { + if defaultClaims.ID == "" { + defaultClaims.ID = uuid.NewString() + } + if defaultClaims.ExpiresAt == nil { + defaultClaims.ExpiresAt = jwtgo.NewNumericDate(time.Now().Add(time.Hour)) // By default, token expires in 1 hour. + } + if defaultClaims.Issuer == "" { + defaultClaims.Issuer = h.Issuer + } } token, err := MakeTokenStringWithHeader(claims, TestKeyID, GetTestRSAPrivateKey(), nil) @@ -65,7 +67,12 @@ func (h *TokenHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { return } - expiresIn := claims.ExpiresAt.Unix() - time.Now().UTC().Unix() + expiresAt, expiresAtErr := claims.GetExpirationTime() + if expiresAtErr != nil { + http.Error(rw, expiresAtErr.Error(), http.StatusInternalServerError) + return + } + expiresIn := expiresAt.Unix() - time.Now().UTC().Unix() if expiresIn < 0 { expiresIn = 0 } @@ -137,9 +144,11 @@ func (h *TokenIntrospectionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Re } } else { if claims, err := h.JWTParser.Parse(r.Context(), token); err == nil { - introspectResult.Active = true - introspectResult.TokenType = idputil.TokenTypeBearer - introspectResult.Claims = *claims + introspectResult = &idptoken.DefaultIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: *claims.(*jwt.DefaultClaims), + } } } diff --git a/idptoken/grpc_client.go b/idptoken/grpc_client.go index 0ee9aae..102ad97 100644 --- a/idptoken/grpc_client.go +++ b/idptoken/grpc_client.go @@ -107,7 +107,7 @@ type TokenData struct { // IntrospectToken introspects the token using the IDP token service. func (c *GRPCClient) IntrospectToken( - ctx context.Context, token string, scopeFilter []IntrospectionScopeFilterAccessPolicy, accessToken string, + ctx context.Context, token string, scopeFilter jwt.ScopeFilter, accessToken string, ) (IntrospectionResult, error) { req := pb.IntrospectTokenRequest{ Token: token, @@ -125,30 +125,29 @@ func (c *GRPCClient) IntrospectToken( resp, innerErr = c.client.IntrospectToken(ctx, &req) return innerErr }); err != nil { - return IntrospectionResult{}, err + return nil, err } - res := IntrospectionResult{ - Active: resp.GetActive(), - TokenType: resp.GetTokenType(), - Claims: jwt.Claims{ - RegisteredClaims: jwtgo.RegisteredClaims{ - Issuer: resp.GetIss(), - Subject: resp.GetSub(), - Audience: resp.GetAud(), - ID: resp.GetJti(), - }, - SubType: resp.GetSubType(), - ClientID: resp.GetClientId(), - OwnerTenantUUID: resp.GetOwnerTenantUuid(), - Scope: make([]jwt.AccessPolicy, len(resp.GetScope())), + claims := jwt.DefaultClaims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: resp.GetIss(), + Subject: resp.GetSub(), + Audience: resp.GetAud(), + ID: resp.GetJti(), }, } if resp.GetExp() != 0 { - res.Claims.ExpiresAt = jwtgo.NewNumericDate(time.Unix(resp.GetExp(), 0)) + claims.ExpiresAt = jwtgo.NewNumericDate(time.Unix(resp.GetExp(), 0)) } + if resp.GetIat() != 0 { + claims.IssuedAt = jwtgo.NewNumericDate(time.Unix(resp.GetIat(), 0)) + } + if resp.GetNbf() != 0 { + claims.NotBefore = jwtgo.NewNumericDate(time.Unix(resp.GetNbf(), 0)) + } + claims.Scope = make([]jwt.AccessPolicy, len(resp.GetScope())) for i, s := range resp.GetScope() { - res.Claims.Scope[i] = jwt.AccessPolicy{ + claims.Scope[i] = jwt.AccessPolicy{ ResourceNamespace: s.GetResourceNamespace(), Role: s.GetRoleName(), ResourceServerID: s.GetResourceServer(), @@ -156,10 +155,14 @@ func (c *GRPCClient) IntrospectToken( TenantUUID: s.GetTenantUuid(), } if s.GetTenantIntId() != 0 { - res.Claims.Scope[i].TenantID = strconv.FormatInt(s.GetTenantIntId(), 10) + claims.Scope[i].TenantID = strconv.FormatInt(s.GetTenantIntId(), 10) } } - return res, nil + return &DefaultIntrospectionResult{ + Active: resp.GetActive(), + TokenType: resp.GetTokenType(), + DefaultClaims: claims, + }, nil } // ExchangeToken exchanges the token requesting a new token with the specified version. diff --git a/idptoken/grpc_client_test.go b/idptoken/grpc_client_test.go index 7923463..9687266 100644 --- a/idptoken/grpc_client_test.go +++ b/idptoken/grpc_client_test.go @@ -26,17 +26,21 @@ import ( func TestGRPCClient_ExchangeToken(t *gotesting.T) { tokenExpiresIn := time.Hour tokenExpiresAt := time.Now().Add(time.Hour) - tokenV1 := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ - RegisteredClaims: jwtgo.RegisteredClaims{ - Subject: "test-subject", - ExpiresAt: jwtgo.NewNumericDate(tokenExpiresAt), + tokenV1 := idptest.MustMakeTokenStringSignedWithTestKey(&VersionedClaims{ + DefaultClaims: jwt.DefaultClaims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Subject: "test-subject", + ExpiresAt: jwtgo.NewNumericDate(tokenExpiresAt), + }, }, Version: 1, }) - tokenV2 := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ - RegisteredClaims: jwtgo.RegisteredClaims{ - Subject: "test-subject", - ExpiresAt: jwtgo.NewNumericDate(tokenExpiresAt), + tokenV2 := idptest.MustMakeTokenStringSignedWithTestKey(&VersionedClaims{ + DefaultClaims: jwt.DefaultClaims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Subject: "test-subject", + ExpiresAt: jwtgo.NewNumericDate(tokenExpiresAt), + }, }, Version: 2, }) @@ -103,3 +107,15 @@ func TestGRPCClient_ExchangeToken(t *gotesting.T) { }) } } + +type VersionedClaims struct { + jwt.DefaultClaims + Version int `json:"ver"` +} + +func (c *VersionedClaims) Clone() jwt.Claims { + return &VersionedClaims{ + DefaultClaims: *c.DefaultClaims.Clone().(*jwt.DefaultClaims), + Version: c.Version, + } +} diff --git a/idptoken/idp_token.proto b/idptoken/idp_token.proto index f50b6b7..d34723e 100644 --- a/idptoken/idp_token.proto +++ b/idptoken/idp_token.proto @@ -59,7 +59,8 @@ message AccessTokenScope { } message IntrospectTokenResponse { - reserved 12 to 100; + reserved 14 to 100; + reserved 8, 9, 10; bool active = 1; string token_type = 2; int64 exp = 3; @@ -67,8 +68,7 @@ message IntrospectTokenResponse { string jti = 5; string iss = 6; string sub = 7; - string sub_type = 8; - string client_id = 9; - string owner_tenant_uuid = 10; // API client owner tenant UUID. repeated AccessTokenScope scope = 11; + int64 nbf = 12; + int64 iat = 13; } \ No newline at end of file diff --git a/idptoken/introspector.go b/idptoken/introspector.go index 3ac7278..a92ce48 100644 --- a/idptoken/introspector.go +++ b/idptoken/introspector.go @@ -70,6 +70,15 @@ var ErrUnauthenticated = errors.New("request is unauthenticated") // For example, it could be analyzed and then added to the list by calling AddTrustedIssuerURL method. type TrustedIssNotFoundFallback func(ctx context.Context, i *Introspector, iss string) (issURL string, issFound bool) +// IntrospectionResult is an interface that must be implemented by introspection result implementations. +// By default, DefaultIntrospectionResult is used. +type IntrospectionResult interface { + IsActive() bool + GetTokenType() string + GetClaims() jwt.Claims + Clone() IntrospectionResult +} + // IntrospectionTokenProvider is an interface for getting access token for doing introspection. // The token should have introspection permission. type IntrospectionTokenProvider interface { @@ -77,14 +86,16 @@ type IntrospectionTokenProvider interface { Invalidate() } -// IntrospectionScopeFilterAccessPolicy is a single access policy for filtering scope during introspection. -type IntrospectionScopeFilterAccessPolicy struct { - ResourceNamespace string +// IntrospectionCache is an interface that must be implemented by used cache implementations. +// The cache is used for storing results of access token introspection. +type IntrospectionCache interface { + Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionCacheItem, bool) + Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionCacheItem) + Remove(ctx context.Context, key [sha256.Size]byte) bool + Purge(ctx context.Context) + Len(ctx context.Context) int } -// IntrospectionScopeFilter is a filter for scope during introspection. -type IntrospectionScopeFilter []IntrospectionScopeFilterAccessPolicy - // IntrospectorOpts is a set of options for creating Introspector. type IntrospectorOpts struct { // GRPCClient is a gRPC client for doing introspection. @@ -107,7 +118,7 @@ type IntrospectorOpts struct { // ScopeFilter is a filter for scope during introspection. // If it's set, then only access policies in scope that match at least one of the filtering policies will be returned. - ScopeFilter IntrospectionScopeFilter + ScopeFilter jwt.ScopeFilter // LoggerProvider is a function that provides a logger for the Introspector. LoggerProvider func(ctx context.Context) log.FieldLogger @@ -128,6 +139,11 @@ type IntrospectorOpts struct { // EndpointDiscoveryCache is a configuration of how endpoint discovery cache will be used. EndpointDiscoveryCache IntrospectorCacheOpts + + // ResultTemplate is a custom introspection result + // that will be used instead of DefaultIntrospectionResult for unmarshalling introspection response. + // It must implement IntrospectionResult interface. + ResultTemplate IntrospectionResult } // IntrospectorCacheOpts is a configuration of how cache will be used. @@ -148,10 +164,10 @@ type Introspector struct { HTTPClient *http.Client // ClaimsCache is a cache for storing claims of introspected active tokens. - ClaimsCache IntrospectionClaimsCache + ClaimsCache IntrospectionCache // NegativeCache is a cache for storing info about tokens that are not active. - NegativeCache IntrospectionNegativeCache + NegativeCache IntrospectionCache // EndpointDiscoveryCache is a cache for storing OpenID configuration. EndpointDiscoveryCache IntrospectionEndpointDiscoveryCache @@ -160,11 +176,13 @@ type Introspector struct { accessTokenProviderInvalidatedAt atomic.Value accessTokenScope []string + resultTemplate IntrospectionResult + jwtParser *jwtgo.Parser httpEndpoint string - scopeFilter IntrospectionScopeFilter + scopeFilter jwt.ScopeFilter scopeFilterFormURLEncoded string loggerProvider func(ctx context.Context) log.FieldLogger @@ -179,35 +197,35 @@ type Introspector struct { endpointDiscoveryCacheTTL time.Duration } -// IntrospectionResult is a struct for introspection result. -type IntrospectionResult struct { +// DefaultIntrospectionResult is a default implementation of IntrospectionResult. +type DefaultIntrospectionResult struct { Active bool `json:"active"` TokenType string `json:"token_type,omitempty"` - jwt.Claims -} - -// ApplyScopeFilter filters the scope of the introspection result -// and preserves policies only that match the filter if it's not empty. -// It's used just in case when the scope filtering is not done on the introspection endpoint side. -func (ir *IntrospectionResult) ApplyScopeFilter(filter IntrospectionScopeFilter) { - if len(filter) == 0 { - return - } - n := 0 - for j := range ir.Claims.Scope { - matched := false - for k := range filter { - if ir.Claims.Scope[j].ResourceNamespace == filter[k].ResourceNamespace { - matched = true - break - } - } - if matched { - ir.Claims.Scope[n] = ir.Claims.Scope[j] - n++ - } + jwt.DefaultClaims +} + +// IsActive returns true if the token is active. +func (ir *DefaultIntrospectionResult) IsActive() bool { + return ir.Active +} + +// GetTokenType returns the token type. +func (ir *DefaultIntrospectionResult) GetTokenType() string { + return ir.TokenType +} + +// GetClaims returns the claims of the token. +func (ir *DefaultIntrospectionResult) GetClaims() jwt.Claims { + return &ir.DefaultClaims +} + +// Clone returns a deep copy of the introspection result. +func (ir *DefaultIntrospectionResult) Clone() IntrospectionResult { + return &DefaultIntrospectionResult{ + Active: ir.Active, + TokenType: ir.TokenType, + DefaultClaims: *ir.DefaultClaims.Clone().(*jwt.DefaultClaims), } - ir.Claims.Scope = ir.Claims.Scope[:n] } // NewIntrospector creates a new Introspector with the given token provider. @@ -234,11 +252,11 @@ func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opt promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, tokenIntrospectorPromSource) - claimsCache := makeIntrospectionClaimsCache(opts.ClaimsCache, promMetrics) + claimsCache := makeIntrospectionClaimsCache(opts.ClaimsCache, DefaultIntrospectionClaimsCacheMaxEntries, promMetrics) if opts.ClaimsCache.TTL == 0 { opts.ClaimsCache.TTL = DefaultIntrospectionClaimsCacheTTL } - negativeCache := makeIntrospectionNegativeCache(opts.NegativeCache, promMetrics) + negativeCache := makeIntrospectionClaimsCache(opts.NegativeCache, DefaultIntrospectionNegativeCacheMaxEntries, promMetrics) if opts.NegativeCache.TTL == 0 { opts.NegativeCache.TTL = DefaultIntrospectionNegativeCacheTTL } @@ -247,9 +265,18 @@ func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opt opts.EndpointDiscoveryCache.TTL = DefaultIntrospectionEndpointDiscoveryCacheTTL } + var resultTemplate IntrospectionResult = &DefaultIntrospectionResult{} + if opts.ResultTemplate != nil { + if opts.GRPCClient != nil { + return nil, errors.New("custom introspection result template is not supported when gRPC is used") + } + resultTemplate = opts.ResultTemplate + } + return &Introspector{ accessTokenProvider: accessTokenProvider, accessTokenScope: opts.AccessTokenScope, + resultTemplate: resultTemplate, jwtParser: jwtgo.NewParser(), loggerProvider: opts.LoggerProvider, GRPCClient: opts.GRPCClient, @@ -276,33 +303,32 @@ func (i *Introspector) IntrospectToken(ctx context.Context, token string) (Intro if cachedItem, ok := i.ClaimsCache.Get(ctx, cacheKey); ok { now := time.Now() - if cachedItem.CreatedAt.Add(i.claimsCacheTTL).After(now) && - (cachedItem.Claims.ExpiresAt == nil || cachedItem.Claims.ExpiresAt.Time.After(now)) { - return IntrospectionResult{Active: true, TokenType: cachedItem.TokenType, - Claims: cloneClaims(cachedItem.Claims)}, nil + if cachedItem.CreatedAt.Add(i.claimsCacheTTL).After(now) { + cachedClaimsExpiresAt, err := cachedItem.IntrospectionResult.GetClaims().GetExpirationTime() + if err != nil { + return nil, fmt.Errorf("get expiration time from cached claims: %w", err) + } + if cachedClaimsExpiresAt == nil || cachedClaimsExpiresAt.Time.After(now) { + return cachedItem.IntrospectionResult.Clone(), nil + } } - } - - if c, ok := i.NegativeCache.Get(ctx, cacheKey); ok { - if c.CreatedAt.Add(i.negativeCacheTTL).After(time.Now()) { - return IntrospectionResult{Active: false}, nil + } else if cachedItem, ok = i.NegativeCache.Get(ctx, cacheKey); ok { + if cachedItem.CreatedAt.Add(i.negativeCacheTTL).After(time.Now()) { + return cachedItem.IntrospectionResult.Clone(), nil } } introspectionResult, err := i.introspectToken(ctx, token) if err != nil { - return IntrospectionResult{}, err - } - if introspectionResult.Active { - introspectionResult.ApplyScopeFilter(i.scopeFilter) - claims := cloneClaims(&introspectionResult.Claims) - i.ClaimsCache.Add(ctx, cacheKey, IntrospectionClaimsCacheItem{ - Claims: &claims, - TokenType: introspectionResult.TokenType, - CreatedAt: time.Now(), - }) + return nil, err + } + if introspectionResult.IsActive() { + introspectionResult.GetClaims().ApplyScopeFilter(i.scopeFilter) + i.ClaimsCache.Add(ctx, cacheKey, IntrospectionCacheItem{ + IntrospectionResult: introspectionResult.Clone(), CreatedAt: time.Now()}) } else { - i.NegativeCache.Add(ctx, cacheKey, IntrospectionNegativeCacheItem{CreatedAt: time.Now()}) + i.NegativeCache.Add(ctx, cacheKey, IntrospectionCacheItem{ + IntrospectionResult: introspectionResult.Clone(), CreatedAt: time.Now()}) } return introspectionResult, nil @@ -321,7 +347,7 @@ func (i *Introspector) AddTrustedIssuerURL(issURL string) error { func (i *Introspector) introspectToken(ctx context.Context, token string) (IntrospectionResult, error) { introspectFn, err := i.makeIntrospectFuncForToken(ctx, token) if err != nil { - return IntrospectionResult{}, err + return nil, err } result, err := introspectFn(ctx, token) @@ -330,7 +356,7 @@ func (i *Introspector) introspectToken(ctx context.Context, token string) (Intro } if !errors.Is(err, ErrUnauthenticated) { - return IntrospectionResult{}, err + return nil, err } // If introspection is unauthorized, then invalidate access token provider's cache and try again. @@ -342,7 +368,7 @@ func (i *Introspector) introspectToken(ctx context.Context, token string) (Intro i.accessTokenProviderInvalidatedAt.Store(now) return introspectFn(ctx, token) } - return IntrospectionResult{}, err + return nil, err } type introspectFunc func(ctx context.Context, token string) (IntrospectionResult, error) @@ -395,14 +421,14 @@ func (i *Introspector) makeIntrospectFuncForToken(ctx context.Context, token str ); err != nil { return nil, makeTokenNotIntrospectableError(fmt.Errorf("decode JWT payload: %w", err)) } - var originalClaims jwt.Claims - if err = json.Unmarshal(jwtPayloadBytes, &originalClaims); err != nil { + var regClaims jwtgo.RegisteredClaims + if err = json.Unmarshal(jwtPayloadBytes, ®Claims); err != nil { return nil, makeTokenNotIntrospectableError(fmt.Errorf("unmarshal JWT payload: %w", err)) } - if originalClaims.Issuer == "" { + if regClaims.Issuer == "" { return nil, makeTokenNotIntrospectableError(fmt.Errorf("no issuer found in JWT")) } - issuer = originalClaims.Issuer + issuer = regClaims.Issuer } issuerURL, ok := i.getURLForIssuerWithCallback(ctx, issuer) @@ -432,7 +458,7 @@ func (i *Introspector) makeIntrospectFuncHTTP(introspectionEndpointURL string) i return func(ctx context.Context, token string) (IntrospectionResult, error) { accessToken, err := i.accessTokenProvider.GetToken(ctx, i.accessTokenScope...) if err != nil { - return IntrospectionResult{}, fmt.Errorf("get access token for doing introspection: %w", err) + return nil, fmt.Errorf("get access token for doing introspection: %w", err) } formEncoded := url.Values{"token": {token}}.Encode() if i.scopeFilterFormURLEncoded != "" { @@ -440,7 +466,7 @@ func (i *Introspector) makeIntrospectFuncHTTP(introspectionEndpointURL string) i } req, err := http.NewRequest(http.MethodPost, introspectionEndpointURL, strings.NewReader(formEncoded)) if err != nil { - return IntrospectionResult{}, fmt.Errorf("new request: %w", err) + return nil, fmt.Errorf("new request: %w", err) } req.Header.Set("Authorization", makeBearerToken(accessToken)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -450,7 +476,7 @@ func (i *Introspector) makeIntrospectFuncHTTP(introspectionEndpointURL string) i elapsed := time.Since(startTime) if err != nil { i.promMetrics.ObserveHTTPClientRequest(http.MethodPost, introspectionEndpointURL, 0, elapsed, metrics.HTTPRequestErrorDo) - return IntrospectionResult{}, fmt.Errorf("do request: %w", err) + return nil, fmt.Errorf("do request: %w", err) } defer func() { if closeBodyErr := resp.Body.Close(); closeBodyErr != nil { @@ -463,16 +489,16 @@ func (i *Introspector) makeIntrospectFuncHTTP(introspectionEndpointURL string) i i.promMetrics.ObserveHTTPClientRequest( http.MethodPost, introspectionEndpointURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorUnexpectedStatusCode) if resp.StatusCode == http.StatusUnauthorized { - return IntrospectionResult{}, ErrUnauthenticated + return nil, ErrUnauthenticated } - return IntrospectionResult{}, fmt.Errorf("unexpected HTTP code %d for POST %s", resp.StatusCode, introspectionEndpointURL) + return nil, fmt.Errorf("unexpected HTTP code %d for POST %s", resp.StatusCode, introspectionEndpointURL) } - var res IntrospectionResult - if err = json.NewDecoder(resp.Body).Decode(&res); err != nil { + res := i.resultTemplate.Clone() + if err = json.NewDecoder(resp.Body).Decode(res); err != nil { i.promMetrics.ObserveHTTPClientRequest( http.MethodPost, introspectionEndpointURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorDecodeBody) - return IntrospectionResult{}, fmt.Errorf("decode response body json for POST %s: %w", introspectionEndpointURL, err) + return nil, fmt.Errorf("decode response body json for POST %s: %w", introspectionEndpointURL, err) } i.promMetrics.ObserveHTTPClientRequest(http.MethodPost, introspectionEndpointURL, resp.StatusCode, elapsed, "") @@ -484,11 +510,11 @@ func (i *Introspector) makeIntrospectFuncGRPC() introspectFunc { return func(ctx context.Context, token string) (IntrospectionResult, error) { accessToken, err := i.accessTokenProvider.GetToken(ctx, i.accessTokenScope...) if err != nil { - return IntrospectionResult{}, fmt.Errorf("get access token for doing introspection: %w", err) + return nil, fmt.Errorf("get access token for doing introspection: %w", err) } res, err := i.GRPCClient.IntrospectToken(ctx, token, i.scopeFilter, accessToken) if err != nil { - return IntrospectionResult{}, fmt.Errorf("introspect token: %w", err) + return nil, fmt.Errorf("introspect token: %w", err) } return res, nil } @@ -585,101 +611,36 @@ func checkIntrospectionRequiredByJWTHeader(jwtHeader map[string]interface{}) boo return true } -// CloneClaims clones the given claims deeply. -func cloneClaims(claims *jwt.Claims) jwt.Claims { - if claims == nil { - return jwt.Claims{} - } - newClaims := jwt.Claims{ - RegisteredClaims: jwtgo.RegisteredClaims{ - Issuer: claims.Issuer, - Subject: claims.Subject, - ID: claims.ID, - }, - Version: claims.Version, - UserID: claims.UserID, - OriginID: claims.OriginID, - ClientID: claims.ClientID, - TOTPTime: claims.TOTPTime, - SubType: claims.SubType, - OwnerTenantUUID: claims.OwnerTenantUUID, - } - if len(claims.Scope) != 0 { - newClaims.Scope = make([]jwt.AccessPolicy, len(claims.Scope)) - copy(newClaims.Scope, claims.Scope) - } - if len(claims.Audience) != 0 { - newClaims.Audience = make(jwtgo.ClaimStrings, len(claims.Audience)) - copy(newClaims.Audience, claims.Audience) - } - if claims.ExpiresAt != nil { - newClaims.ExpiresAt = jwtgo.NewNumericDate(claims.ExpiresAt.Time) - } - if claims.NotBefore != nil { - newClaims.NotBefore = jwtgo.NewNumericDate(claims.NotBefore.Time) - } - if claims.IssuedAt != nil { - newClaims.IssuedAt = jwtgo.NewNumericDate(claims.IssuedAt.Time) - } - return newClaims -} - -type IntrospectionClaimsCacheItem struct { - Claims *jwt.Claims - TokenType string - CreatedAt time.Time -} - -type IntrospectionClaimsCache interface { - Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionClaimsCacheItem, bool) - Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionClaimsCacheItem) - Purge(ctx context.Context) - Len(ctx context.Context) int +type IntrospectionCacheItem struct { + IntrospectionResult IntrospectionResult + CreatedAt time.Time } -func makeIntrospectionClaimsCache(opts IntrospectorCacheOpts, promMetrics *metrics.PrometheusMetrics) IntrospectionClaimsCache { +func makeIntrospectionClaimsCache( + opts IntrospectorCacheOpts, defaultMaxEntries int, promMetrics *metrics.PrometheusMetrics, +) IntrospectionCache { if !opts.Enabled { - return &disabledIntrospectionClaimsCache{} + return &disabledIntrospectionCache{} } if opts.MaxEntries <= 0 { - opts.MaxEntries = DefaultIntrospectionClaimsCacheMaxEntries + opts.MaxEntries = defaultMaxEntries } - cache, _ := lrucache.New[[sha256.Size]byte, IntrospectionClaimsCacheItem]( + cache, _ := lrucache.New[[sha256.Size]byte, IntrospectionCacheItem]( opts.MaxEntries, promMetrics.TokenClaimsCache) // error is always nil here - return &IntrospectionLRUCache[[sha256.Size]byte, IntrospectionClaimsCacheItem]{cache} -} - -type IntrospectionNegativeCacheItem struct { - CreatedAt time.Time -} - -type IntrospectionNegativeCache interface { - Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionNegativeCacheItem, bool) - Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionNegativeCacheItem) - Purge(ctx context.Context) - Len(ctx context.Context) int -} - -func makeIntrospectionNegativeCache(opts IntrospectorCacheOpts, promMetrics *metrics.PrometheusMetrics) IntrospectionNegativeCache { - if !opts.Enabled { - return &disabledIntrospectionNegativeCache{} - } - if opts.TTL == 0 { - opts.TTL = DefaultIntrospectionNegativeCacheTTL - } - if opts.MaxEntries <= 0 { - opts.MaxEntries = DefaultIntrospectionNegativeCacheMaxEntries - } - cache, _ := lrucache.New[[sha256.Size]byte, IntrospectionNegativeCacheItem]( - opts.MaxEntries, promMetrics.TokenNegativeCache) // error is always nil here - return &IntrospectionLRUCache[[sha256.Size]byte, IntrospectionNegativeCacheItem]{cache} + return &IntrospectionLRUCache[[sha256.Size]byte, IntrospectionCacheItem]{cache} } +// IntrospectionEndpointDiscoveryCacheItem is an item in the introspection endpoint discovery cache. type IntrospectionEndpointDiscoveryCacheItem struct { + // IntrospectionEndpoint is an introspection endpoint URL. IntrospectionEndpoint string - CreatedAt time.Time + + // CreatedAt is a time when the item was created in the cache. + CreatedAt time.Time } +// IntrospectionEndpointDiscoveryCache is an interface +// that must be implemented by used endpoint discovery cache implementations. type IntrospectionEndpointDiscoveryCache interface { Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionEndpointDiscoveryCacheItem, bool) Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionEndpointDiscoveryCacheItem) @@ -716,6 +677,10 @@ func (a *IntrospectionLRUCache[K, V]) Add(_ context.Context, key K, val V) { a.cache.Add(key, val) } +func (a *IntrospectionLRUCache[K, V]) Remove(_ context.Context, key K) bool { + return a.cache.Remove(key) +} + func (a *IntrospectionLRUCache[K, V]) Purge(ctx context.Context) { a.cache.Purge() } @@ -724,25 +689,18 @@ func (a *IntrospectionLRUCache[K, V]) Len(ctx context.Context) int { return a.cache.Len() } -type disabledIntrospectionClaimsCache struct{} +type disabledIntrospectionCache struct{} -func (c *disabledIntrospectionClaimsCache) Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionClaimsCacheItem, bool) { - return IntrospectionClaimsCacheItem{}, false -} -func (c *disabledIntrospectionClaimsCache) Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionClaimsCacheItem) { +func (c *disabledIntrospectionCache) Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionCacheItem, bool) { + return IntrospectionCacheItem{}, false } -func (c *disabledIntrospectionClaimsCache) Purge(ctx context.Context) {} -func (c *disabledIntrospectionClaimsCache) Len(ctx context.Context) int { return 0 } - -type disabledIntrospectionNegativeCache struct{} - -func (c *disabledIntrospectionNegativeCache) Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionNegativeCacheItem, bool) { - return IntrospectionNegativeCacheItem{}, false +func (c *disabledIntrospectionCache) Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionCacheItem) { } -func (c *disabledIntrospectionNegativeCache) Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionNegativeCacheItem) { +func (c *disabledIntrospectionCache) Remove(ctx context.Context, key [sha256.Size]byte) bool { + return false } -func (c *disabledIntrospectionNegativeCache) Purge(ctx context.Context) {} -func (c *disabledIntrospectionNegativeCache) Len(ctx context.Context) int { return 0 } +func (c *disabledIntrospectionCache) Purge(ctx context.Context) {} +func (c *disabledIntrospectionCache) Len(ctx context.Context) int { return 0 } type disabledIntrospectionEndpointDiscoveryCache struct{} diff --git a/idptoken/introspector_test.go b/idptoken/introspector_test.go index 1cc8439..9054328 100644 --- a/idptoken/introspector_test.go +++ b/idptoken/introspector_test.go @@ -61,7 +61,7 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { } // Expired JWT - expiredJWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + expiredJWT := idptest.MustMakeTokenStringSignedWithTestKey(&jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: httpIDPSrv.URL(), Subject: uuid.NewString(), @@ -69,7 +69,7 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(-time.Hour)), }, }) - httpServerIntrospector.SetResultForToken(expiredJWT, idptoken.IntrospectionResult{Active: false}) + httpServerIntrospector.SetResultForToken(expiredJWT, &idptoken.DefaultIntrospectionResult{Active: false}) // Valid JWT with scope validJWTScope := []jwt.AccessPolicy{{ @@ -84,14 +84,44 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { ID: uuid.NewString(), ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), } - validJWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{RegisteredClaims: validJWTClaims}) - httpServerIntrospector.SetResultForToken(validJWT, idptoken.IntrospectionResult{Active: true, - TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{RegisteredClaims: validJWTClaims, Scope: validJWTScope}}) - validJWTWithAppTyp := idptest.MustMakeTokenStringWithHeader(jwt.Claims{ + validJWT := idptest.MustMakeTokenStringSignedWithTestKey(&jwt.DefaultClaims{RegisteredClaims: validJWTClaims}) + httpServerIntrospector.SetResultForToken(validJWT, &idptoken.DefaultIntrospectionResult{Active: true, + TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validJWTClaims, Scope: validJWTScope}}) + validJWTWithAppTyp := idptest.MustMakeTokenStringWithHeader(&jwt.DefaultClaims{ RegisteredClaims: validJWTClaims, }, idptest.TestKeyID, idptest.GetTestRSAPrivateKey(), map[string]interface{}{"typ": idputil.JWTTypeAppAccessToken}) - httpServerIntrospector.SetResultForToken(validJWTWithAppTyp, idptoken.IntrospectionResult{Active: true, - TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{RegisteredClaims: validJWTClaims, Scope: validJWTScope}}) + httpServerIntrospector.SetResultForToken(validJWTWithAppTyp, &idptoken.DefaultIntrospectionResult{Active: true, + TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validJWTClaims, Scope: validJWTScope}}) + + // Valid JWT with scope and custom claims fields + customFieldVal := uuid.NewString() + validCustomJWTScope := jwt.Scope{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "account_viewer", + ResourcePath: "resource-1", + }, { + TenantUUID: uuid.NewString(), + ResourceNamespace: "event-manager", + Role: "publisher", + ResourcePath: "resource-1", + }} + validCustomJWTClaims := jwtgo.RegisteredClaims{ + Issuer: httpIDPSrv.URL(), + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + } + validCustomJWT := idptest.MustMakeTokenStringSignedWithTestKey(&CustomClaims{ + DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validCustomJWTClaims}, CustomField: customFieldVal}) + httpServerIntrospector.SetResultForToken(validCustomJWT, &CustomIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + CustomClaims: CustomClaims{ + DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validCustomJWTClaims, Scope: validCustomJWTScope}, + CustomField: customFieldVal, + }, + }) // Opaque token opaqueToken := "opaque-token-" + uuid.NewString() @@ -101,8 +131,8 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { Role: "admin", ResourcePath: "resource-" + uuid.NewString(), }} - httpServerIntrospector.SetResultForToken(opaqueToken, idptoken.IntrospectionResult{ - Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueTokenScope}}) + httpServerIntrospector.SetResultForToken(opaqueToken, &idptoken.DefaultIntrospectionResult{ + Active: true, TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{Scope: opaqueTokenScope}}) grpcServerIntrospector.SetResultForToken(opaqueToken, &pb.IntrospectTokenResponse{Active: true, TokenType: idputil.TokenTypeBearer, Scope: jwtScopeToGRPC(opaqueTokenScope)}) @@ -144,7 +174,7 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { }, { name: `error, dynamic introspection endpoint, invalid "typ" field in JWT header`, - tokenToIntrospect: idptest.MustMakeTokenStringWithHeader(jwt.Claims{ + tokenToIntrospect: idptest.MustMakeTokenStringWithHeader(&jwt.DefaultClaims{ RegisteredClaims: validJWTClaims, }, idptest.TestKeyID, idptest.GetTestRSAPrivateKey(), map[string]interface{}{"typ": "invalid"}), checkError: func(t *gotesting.T, err error) { @@ -155,7 +185,7 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { }, { name: "error, dynamic introspection endpoint, issuer is not trusted", - tokenToIntrospect: idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + tokenToIntrospect: idptest.MustMakeTokenStringSignedWithTestKey(&jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: "https://untrusted-issuer.com", Subject: uuid.NewString(), @@ -170,7 +200,7 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { }, { name: "error, dynamic introspection endpoint, issuer is missing in JWT header and payload", - tokenToIntrospect: idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + tokenToIntrospect: idptest.MustMakeTokenStringSignedWithTestKey(&jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Subject: uuid.NewString(), ID: uuid.NewString(), @@ -184,7 +214,7 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { }, { name: "error, dynamic introspection endpoint, nri is 1", - tokenToIntrospect: idptest.MustMakeTokenStringWithHeader(jwt.Claims{ + tokenToIntrospect: idptest.MustMakeTokenStringWithHeader(&jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Subject: uuid.NewString(), ID: uuid.NewString(), @@ -197,7 +227,7 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { }, { name: "error, dynamic introspection endpoint, nri is true", - tokenToIntrospect: idptest.MustMakeTokenStringWithHeader(jwt.Claims{ + tokenToIntrospect: idptest.MustMakeTokenStringWithHeader(&jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Subject: uuid.NewString(), ID: uuid.NewString(), @@ -211,28 +241,49 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { { name: "ok, dynamic introspection endpoint, introspected token is expired JWT", tokenToIntrospect: expiredJWT, - expectedResult: idptoken.IntrospectionResult{Active: false}, + expectedResult: &idptoken.DefaultIntrospectionResult{Active: false}, expectedHTTPSrvCalled: true, }, { name: "ok, dynamic introspection endpoint, introspected token is JWT", tokenToIntrospect: validJWT, - expectedResult: idptoken.IntrospectionResult{ - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{RegisteredClaims: validJWTClaims, Scope: validJWTScope}, + expectedResult: &idptoken.DefaultIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validJWTClaims, Scope: validJWTScope}, }, expectedHTTPSrvCalled: true, }, { name: `ok, dynamic introspection endpoint, introspected token is JWT, "typ" is "application/at+jwt"`, tokenToIntrospect: validJWTWithAppTyp, - expectedResult: idptoken.IntrospectionResult{ + expectedResult: &idptoken.DefaultIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validJWTClaims, Scope: validJWTScope}, + }, + expectedHTTPSrvCalled: true, + }, + { + name: "ok, dynamic introspection endpoint, introspected token is JWT, custom claims, filter scope by resource namespace", + introspectorOpts: idptoken.IntrospectorOpts{ + ResultTemplate: &CustomIntrospectionResult{CustomClaims: CustomClaims{}}, + ScopeFilter: jwt.ScopeFilter{{ResourceNamespace: "event-manager"}}, + }, + tokenToIntrospect: validCustomJWT, + expectedResult: &CustomIntrospectionResult{ Active: true, TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{RegisteredClaims: validJWTClaims, Scope: validJWTScope}, + CustomClaims: CustomClaims{ + DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validCustomJWTClaims, Scope: jwt.Scope{validCustomJWTScope[1]}}, + CustomField: customFieldVal, + }, }, expectedHTTPSrvCalled: true, + expectedHTTPFormVals: url.Values{ + "token": {validCustomJWT}, + "scope_filter[0].rn": {"event-manager"}, + }, }, { name: "ok, static http introspection endpoint, introspected token is opaque", @@ -240,10 +291,10 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { HTTPEndpoint: httpIDPSrv.URL() + idptest.TokenIntrospectionEndpointPath, }, tokenToIntrospect: opaqueToken, - expectedResult: idptoken.IntrospectionResult{ - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{Scope: opaqueTokenScope}, + expectedResult: &idptoken.DefaultIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: jwt.DefaultClaims{Scope: opaqueTokenScope}, }, expectedHTTPSrvCalled: true, }, @@ -251,16 +302,16 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { name: "ok, static http introspection endpoint, introspected token is opaque, filter scope by resource namespace", introspectorOpts: idptoken.IntrospectorOpts{ HTTPEndpoint: httpIDPSrv.URL() + idptest.TokenIntrospectionEndpointPath, - ScopeFilter: []idptoken.IntrospectionScopeFilterAccessPolicy{ + ScopeFilter: jwt.ScopeFilter{ {ResourceNamespace: "account-server"}, {ResourceNamespace: "tenant-manager"}, }, }, tokenToIntrospect: opaqueToken, - expectedResult: idptoken.IntrospectionResult{ - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{Scope: opaqueTokenScope}, + expectedResult: &idptoken.DefaultIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: jwt.DefaultClaims{Scope: opaqueTokenScope}, }, expectedHTTPSrvCalled: true, expectedHTTPFormVals: url.Values{ @@ -285,16 +336,16 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { name: "ok, grpc introspection endpoint", introspectorOpts: idptoken.IntrospectorOpts{ GRPCClient: grpcClient, - ScopeFilter: []idptoken.IntrospectionScopeFilterAccessPolicy{ + ScopeFilter: jwt.ScopeFilter{ {ResourceNamespace: "account-server"}, {ResourceNamespace: "tenant-manager"}, }, }, tokenToIntrospect: opaqueToken, - expectedResult: idptoken.IntrospectionResult{ - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{Scope: opaqueTokenScope}, + expectedResult: &idptoken.DefaultIntrospectionResult{ + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: jwt.DefaultClaims{Scope: opaqueTokenScope}, }, expectedGRPCSrvCalled: true, expectedGRPCScopeFilter: []*pb.IntrospectionScopeFilter{ @@ -367,7 +418,7 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { defer func() { _ = idpSrv.Shutdown(context.Background()) }() // Expired JWT - expiredJWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + expiredJWT := idptest.MustMakeTokenStringSignedWithTestKey(&jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: idpSrv.URL(), Subject: uuid.NewString(), @@ -375,7 +426,7 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(-time.Hour)), }, }) - serverIntrospector.SetResultForToken(expiredJWT, idptoken.IntrospectionResult{Active: false}) + serverIntrospector.SetResultForToken(expiredJWT, &idptoken.DefaultIntrospectionResult{Active: false}) // Valid JWTs with scope validJWT1Scope := []jwt.AccessPolicy{{ @@ -390,9 +441,9 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { ID: uuid.NewString(), ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(2 * time.Hour)), } - valid1JWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{RegisteredClaims: validJWT1Claims}) - serverIntrospector.SetResultForToken(valid1JWT, idptoken.IntrospectionResult{Active: true, - TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{RegisteredClaims: validJWT1Claims, Scope: validJWT1Scope}}) + valid1JWT := idptest.MustMakeTokenStringSignedWithTestKey(&jwt.DefaultClaims{RegisteredClaims: validJWT1Claims}) + serverIntrospector.SetResultForToken(valid1JWT, &idptoken.DefaultIntrospectionResult{Active: true, + TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validJWT1Claims, Scope: validJWT1Scope}}) validJWT2Scope := []jwt.AccessPolicy{{ TenantUUID: uuid.NewString(), ResourceNamespace: "account-server", @@ -405,9 +456,9 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { ID: uuid.NewString(), ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), } - valid2JWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{RegisteredClaims: validJWT2Claims}) - serverIntrospector.SetResultForToken(valid2JWT, idptoken.IntrospectionResult{Active: true, - TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{RegisteredClaims: validJWT2Claims, Scope: validJWT2Scope}}) + valid2JWT := idptest.MustMakeTokenStringSignedWithTestKey(&jwt.DefaultClaims{RegisteredClaims: validJWT2Claims}) + serverIntrospector.SetResultForToken(valid2JWT, &idptoken.DefaultIntrospectionResult{Active: true, + TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validJWT2Claims, Scope: validJWT2Scope}}) // Opaque tokens opaqueToken1 := "opaque-token-" + uuid.NewString() @@ -425,18 +476,18 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { Role: "admin", ResourcePath: "resource-" + uuid.NewString(), }} - serverIntrospector.SetResultForToken(opaqueToken1, idptoken.IntrospectionResult{ - Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}) - serverIntrospector.SetResultForToken(opaqueToken2, idptoken.IntrospectionResult{ - Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken2Scope}}) - serverIntrospector.SetResultForToken(opaqueToken3, idptoken.IntrospectionResult{Active: false}) + serverIntrospector.SetResultForToken(opaqueToken1, &idptoken.DefaultIntrospectionResult{ + Active: true, TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{Scope: opaqueToken1Scope}}) + serverIntrospector.SetResultForToken(opaqueToken2, &idptoken.DefaultIntrospectionResult{ + Active: true, TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{Scope: opaqueToken2Scope}}) + serverIntrospector.SetResultForToken(opaqueToken3, &idptoken.DefaultIntrospectionResult{Active: false}) tests := []struct { name string introspectorOpts idptoken.IntrospectorOpts tokens []string expectedSrvCounts []map[string]uint64 - expectedResult []idptoken.IntrospectionResult + expectedResult []*idptoken.DefaultIntrospectionResult checkError []func(t *gotesting.T, err error) checkIntrospector func(t *gotesting.T, introspector *idptoken.Introspector) delay time.Duration @@ -472,12 +523,12 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { NegativeCache: idptoken.IntrospectorCacheOpts{Enabled: true}, EndpointDiscoveryCache: idptoken.IntrospectorCacheOpts{Enabled: true}, }, - tokens: repeat(expiredJWT, 2), + tokens: []string{expiredJWT, expiredJWT}, expectedSrvCounts: []map[string]uint64{ {idptest.TokenIntrospectionEndpointPath: 1, idptest.OpenIDConfigurationPath: 1}, {}, }, - expectedResult: []idptoken.IntrospectionResult{{Active: false}, {Active: false}}, + expectedResult: []*idptoken.DefaultIntrospectionResult{{Active: false}, {Active: false}}, checkIntrospector: func(t *gotesting.T, introspector *idptoken.Introspector) { require.Equal(t, 0, introspector.ClaimsCache.Len(context.Background())) require.Equal(t, 1, introspector.NegativeCache.Len(context.Background())) @@ -498,26 +549,26 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { {idptest.TokenIntrospectionEndpointPath: 1}, {}, }, - expectedResult: []idptoken.IntrospectionResult{ + expectedResult: []*idptoken.DefaultIntrospectionResult{ { - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{RegisteredClaims: validJWT1Claims, Scope: validJWT1Scope}, + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validJWT1Claims, Scope: validJWT1Scope}, }, { - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{RegisteredClaims: validJWT1Claims, Scope: validJWT1Scope}, + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validJWT1Claims, Scope: validJWT1Scope}, }, { - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{RegisteredClaims: validJWT2Claims, Scope: validJWT2Scope}, + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validJWT2Claims, Scope: validJWT2Scope}, }, { - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{RegisteredClaims: validJWT2Claims, Scope: validJWT2Scope}, + Active: true, + TokenType: idputil.TokenTypeBearer, + DefaultClaims: jwt.DefaultClaims{RegisteredClaims: validJWT2Claims, Scope: validJWT2Scope}, }, }, checkIntrospector: func(t *gotesting.T, introspector *idptoken.Introspector) { @@ -543,11 +594,11 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { {idptest.TokenIntrospectionEndpointPath: 1}, {idptest.TokenIntrospectionEndpointPath: 0}, }, - expectedResult: []idptoken.IntrospectionResult{ - {Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, - {Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, - {Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken2Scope}}, - {Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken2Scope}}, + expectedResult: []*idptoken.DefaultIntrospectionResult{ + {Active: true, TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{Scope: opaqueToken1Scope}}, + {Active: true, TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{Scope: opaqueToken1Scope}}, + {Active: true, TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{Scope: opaqueToken2Scope}}, + {Active: true, TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{Scope: opaqueToken2Scope}}, {Active: false}, {Active: false}, }, @@ -572,9 +623,9 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { {idptest.TokenIntrospectionEndpointPath: 1}, {idptest.TokenIntrospectionEndpointPath: 1}, }, - expectedResult: []idptoken.IntrospectionResult{ - {Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, - {Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, + expectedResult: []*idptoken.DefaultIntrospectionResult{ + {Active: true, TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{Scope: opaqueToken1Scope}}, + {Active: true, TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{Scope: opaqueToken1Scope}}, {Active: false}, {Active: false}, }, @@ -625,10 +676,40 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { } } -func repeat[V any](v V, n int) []V { - s := make([]V, n) - for i := range s { - s[i] = v +type CustomClaims struct { + jwt.DefaultClaims + CustomField string `json:"custom_field"` +} + +func (c *CustomClaims) Clone() jwt.Claims { + return &CustomClaims{ + DefaultClaims: *c.DefaultClaims.Clone().(*jwt.DefaultClaims), + CustomField: c.CustomField, + } +} + +type CustomIntrospectionResult struct { + Active bool `json:"active"` + TokenType string `json:"token_type,omitempty"` + CustomClaims +} + +func (ir *CustomIntrospectionResult) IsActive() bool { + return ir.Active +} + +func (ir *CustomIntrospectionResult) GetTokenType() string { + return ir.TokenType +} + +func (ir *CustomIntrospectionResult) GetClaims() jwt.Claims { + return &ir.CustomClaims +} + +func (ir *CustomIntrospectionResult) Clone() idptoken.IntrospectionResult { + return &CustomIntrospectionResult{ + Active: ir.Active, + TokenType: ir.TokenType, + CustomClaims: *ir.CustomClaims.Clone().(*CustomClaims), } - return s } diff --git a/idptoken/pb/idp_token.pb.go b/idptoken/pb/idp_token.pb.go index 159f472..23d1cf9 100644 --- a/idptoken/pb/idp_token.pb.go +++ b/idptoken/pb/idp_token.pb.go @@ -345,17 +345,16 @@ type IntrospectTokenResponse struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Active bool `protobuf:"varint,1,opt,name=active,proto3" json:"active,omitempty"` - TokenType string `protobuf:"bytes,2,opt,name=token_type,json=tokenType,proto3" json:"token_type,omitempty"` - Exp int64 `protobuf:"varint,3,opt,name=exp,proto3" json:"exp,omitempty"` - Aud []string `protobuf:"bytes,4,rep,name=aud,proto3" json:"aud,omitempty"` - Jti string `protobuf:"bytes,5,opt,name=jti,proto3" json:"jti,omitempty"` - Iss string `protobuf:"bytes,6,opt,name=iss,proto3" json:"iss,omitempty"` - Sub string `protobuf:"bytes,7,opt,name=sub,proto3" json:"sub,omitempty"` - SubType string `protobuf:"bytes,8,opt,name=sub_type,json=subType,proto3" json:"sub_type,omitempty"` - ClientId string `protobuf:"bytes,9,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` - OwnerTenantUuid string `protobuf:"bytes,10,opt,name=owner_tenant_uuid,json=ownerTenantUuid,proto3" json:"owner_tenant_uuid,omitempty"` // API client owner tenant UUID. - Scope []*AccessTokenScope `protobuf:"bytes,11,rep,name=scope,proto3" json:"scope,omitempty"` + Active bool `protobuf:"varint,1,opt,name=active,proto3" json:"active,omitempty"` + TokenType string `protobuf:"bytes,2,opt,name=token_type,json=tokenType,proto3" json:"token_type,omitempty"` + Exp int64 `protobuf:"varint,3,opt,name=exp,proto3" json:"exp,omitempty"` + Aud []string `protobuf:"bytes,4,rep,name=aud,proto3" json:"aud,omitempty"` + Jti string `protobuf:"bytes,5,opt,name=jti,proto3" json:"jti,omitempty"` + Iss string `protobuf:"bytes,6,opt,name=iss,proto3" json:"iss,omitempty"` + Sub string `protobuf:"bytes,7,opt,name=sub,proto3" json:"sub,omitempty"` + Scope []*AccessTokenScope `protobuf:"bytes,11,rep,name=scope,proto3" json:"scope,omitempty"` + Nbf int64 `protobuf:"varint,12,opt,name=nbf,proto3" json:"nbf,omitempty"` + Iat int64 `protobuf:"varint,13,opt,name=iat,proto3" json:"iat,omitempty"` } func (x *IntrospectTokenResponse) Reset() { @@ -439,32 +438,25 @@ func (x *IntrospectTokenResponse) GetSub() string { return "" } -func (x *IntrospectTokenResponse) GetSubType() string { - if x != nil { - return x.SubType - } - return "" -} - -func (x *IntrospectTokenResponse) GetClientId() string { +func (x *IntrospectTokenResponse) GetScope() []*AccessTokenScope { if x != nil { - return x.ClientId + return x.Scope } - return "" + return nil } -func (x *IntrospectTokenResponse) GetOwnerTenantUuid() string { +func (x *IntrospectTokenResponse) GetNbf() int64 { if x != nil { - return x.OwnerTenantUuid + return x.Nbf } - return "" + return 0 } -func (x *IntrospectTokenResponse) GetScope() []*AccessTokenScope { +func (x *IntrospectTokenResponse) GetIat() int64 { if x != nil { - return x.Scope + return x.Iat } - return nil + return 0 } var File_idp_token_proto protoreflect.FileDescriptor @@ -516,7 +508,7 @@ var file_idp_token_proto_rawDesc = []byte{ 0x52, 0x0c, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x61, 0x74, 0x68, 0x12, 0x1b, 0x0a, 0x09, 0x72, 0x6f, 0x6c, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x4a, 0x04, 0x08, 0x07, 0x10, - 0x33, 0x22, 0xc7, 0x02, 0x0a, 0x17, 0x49, 0x6e, 0x74, 0x72, 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, + 0x33, 0x22, 0x99, 0x02, 0x0a, 0x17, 0x49, 0x6e, 0x74, 0x72, 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x74, @@ -526,30 +518,27 @@ var file_idp_token_proto_rawDesc = []byte{ 0x03, 0x28, 0x09, 0x52, 0x03, 0x61, 0x75, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6a, 0x74, 0x69, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6a, 0x74, 0x69, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x73, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x69, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, - 0x73, 0x75, 0x62, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x75, 0x62, 0x12, 0x19, - 0x0a, 0x08, 0x73, 0x75, 0x62, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x07, 0x73, 0x75, 0x62, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, - 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2a, 0x0a, 0x11, 0x6f, 0x77, 0x6e, 0x65, 0x72, 0x5f, - 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x5f, 0x75, 0x75, 0x69, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0f, 0x6f, 0x77, 0x6e, 0x65, 0x72, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x55, 0x75, - 0x69, 0x64, 0x12, 0x31, 0x0a, 0x05, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x0b, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x1b, 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x41, 0x63, - 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x52, 0x05, - 0x73, 0x63, 0x6f, 0x70, 0x65, 0x4a, 0x04, 0x08, 0x0c, 0x10, 0x65, 0x32, 0xb9, 0x01, 0x0a, 0x0f, - 0x49, 0x44, 0x50, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, - 0x4c, 0x0a, 0x0b, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x1d, - 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, - 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, - 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x58, 0x0a, - 0x0f, 0x49, 0x6e, 0x74, 0x72, 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, - 0x12, 0x21, 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x49, 0x6e, 0x74, - 0x72, 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, - 0x49, 0x6e, 0x74, 0x72, 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x06, 0x5a, 0x04, 0x2e, 0x2f, 0x70, 0x62, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x73, 0x75, 0x62, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x75, 0x62, 0x12, 0x31, + 0x0a, 0x05, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, + 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x52, 0x05, 0x73, 0x63, 0x6f, 0x70, + 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6e, 0x62, 0x66, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, + 0x6e, 0x62, 0x66, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x61, 0x74, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x03, 0x69, 0x61, 0x74, 0x4a, 0x04, 0x08, 0x0e, 0x10, 0x65, 0x4a, 0x04, 0x08, 0x08, 0x10, + 0x09, 0x4a, 0x04, 0x08, 0x09, 0x10, 0x0a, 0x4a, 0x04, 0x08, 0x0a, 0x10, 0x0b, 0x32, 0xb9, 0x01, + 0x0a, 0x0f, 0x49, 0x44, 0x50, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x12, 0x4c, 0x0a, 0x0b, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, + 0x12, 0x1d, 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x43, 0x72, 0x65, + 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x1e, 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x43, 0x72, 0x65, 0x61, + 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x58, 0x0a, 0x0f, 0x49, 0x6e, 0x74, 0x72, 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, + 0x65, 0x6e, 0x12, 0x21, 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x49, + 0x6e, 0x74, 0x72, 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, + 0x6e, 0x2e, 0x49, 0x6e, 0x74, 0x72, 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x06, 0x5a, 0x04, 0x2e, 0x2f, 0x70, + 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/idptoken/provider_test.go b/idptoken/provider_test.go index 2191e3c..a6165aa 100644 --- a/idptoken/provider_test.go +++ b/idptoken/provider_test.go @@ -439,7 +439,7 @@ type claimsProviderWithExpiration struct { } func (d *claimsProviderWithExpiration) Provide(_ *http.Request) (jwt.Claims, error) { - claims := jwt.Claims{ + claims := &jwt.DefaultClaims{ // nolint:staticcheck // StandardClaims are used here for test purposes RegisteredClaims: jwtgo.RegisteredClaims{ ID: uuid.NewString(), @@ -452,8 +452,6 @@ func (d *claimsProviderWithExpiration) Provide(_ *http.Request) (jwt.Claims, err Role: "tenant:viewer", }, }, - Version: 1, - UserID: "1", } if d.ExpTime <= 0 { diff --git a/internal/testing/server_token_introspector_mock.go b/internal/testing/server_token_introspector_mock.go index 98ade72..cb9aef0 100644 --- a/internal/testing/server_token_introspector_mock.go +++ b/internal/testing/server_token_introspector_mock.go @@ -24,7 +24,7 @@ import ( ) type JWTParser interface { - Parse(ctx context.Context, token string) (*jwt.Claims, error) + Parse(ctx context.Context, token string) (jwt.Claims, error) } type HTTPServerTokenIntrospectorMock struct { @@ -69,7 +69,7 @@ func (m *HTTPServerTokenIntrospectorMock) IntrospectToken( m.LastFormValues = r.Form if m.LastAuthorizationHeader != "Bearer "+m.accessTokenForIntrospection { - return idptoken.IntrospectionResult{}, idptest.ErrUnauthorized + return nil, idptest.ErrUnauthorized } if result, ok := m.introspectionResults[tokenToKey(token)]; ok { @@ -78,10 +78,11 @@ func (m *HTTPServerTokenIntrospectorMock) IntrospectToken( claims, err := m.JWTParser.Parse(r.Context(), token) if err != nil { - return idptoken.IntrospectionResult{Active: false}, nil + return &idptoken.DefaultIntrospectionResult{Active: false}, nil } - result := idptoken.IntrospectionResult{Active: true, TokenType: idputil.TokenTypeBearer, Claims: *claims} - if scopes, ok := m.jwtScopes[claims.ID]; ok { + defaultClaims := claims.(*jwt.DefaultClaims) + result := &idptoken.DefaultIntrospectionResult{Active: true, TokenType: idputil.TokenTypeBearer, DefaultClaims: *defaultClaims} + if scopes, ok := m.jwtScopes[defaultClaims.ID]; ok { result.Scope = scopes } return result, nil @@ -152,19 +153,17 @@ func (m *GRPCServerTokenIntrospectorMock) IntrospectToken( if err != nil { return &pb.IntrospectTokenResponse{Active: false}, nil } + defaultClaims := claims.(*jwt.DefaultClaims) result := &pb.IntrospectTokenResponse{ - Active: true, - TokenType: idputil.TokenTypeBearer, - Exp: claims.ExpiresAt.Unix(), - Aud: claims.Audience, - Jti: claims.ID, - Iss: claims.Issuer, - Sub: claims.Subject, - SubType: claims.SubType, - ClientId: claims.ClientID, - OwnerTenantUuid: claims.OwnerTenantUUID, + Active: true, + TokenType: idputil.TokenTypeBearer, + Exp: defaultClaims.ExpiresAt.Unix(), + Aud: defaultClaims.Audience, + Jti: defaultClaims.ID, + Iss: defaultClaims.Issuer, + Sub: defaultClaims.Subject, } - if scopes, ok := m.scopes[claims.ID]; ok { + if scopes, ok := m.scopes[defaultClaims.ID]; ok { result.Scope = scopes } return result, nil diff --git a/jwt/caching_parser.go b/jwt/caching_parser.go index edc80c1..b7fa6ae 100644 --- a/jwt/caching_parser.go +++ b/jwt/caching_parser.go @@ -28,8 +28,8 @@ type CachingParserOpts struct { // ClaimsCache is an interface that must be implemented by used cache implementations. type ClaimsCache interface { - Get(key [sha256.Size]byte) (*Claims, bool) - Add(key [sha256.Size]byte, value *Claims) + Get(key [sha256.Size]byte) (Claims, bool) + Add(key [sha256.Size]byte, claims Claims) Purge() Len() int } @@ -37,7 +37,8 @@ type ClaimsCache interface { // CachingParser uses the functionality of Parser to parse JWT, but stores resulted Claims objects in the cache. type CachingParser struct { *Parser - ClaimsCache ClaimsCache + ClaimsCache ClaimsCache + claimsValidator *jwtgo.Validator } func NewCachingParser(keysProvider KeysProvider) (*CachingParser, error) { @@ -51,13 +52,14 @@ func NewCachingParserWithOpts( if opts.CacheMaxEntries == 0 { opts.CacheMaxEntries = DefaultClaimsCacheMaxEntries } - cache, err := lrucache.New[[sha256.Size]byte, *Claims](opts.CacheMaxEntries, promMetrics.TokenClaimsCache) + cache, err := lrucache.New[[sha256.Size]byte, Claims](opts.CacheMaxEntries, promMetrics.TokenClaimsCache) if err != nil { return nil, err } return &CachingParser{ - Parser: NewParserWithOpts(keysProvider, opts.ParserOpts), - ClaimsCache: cache, + Parser: NewParserWithOpts(keysProvider, opts.ParserOpts), + ClaimsCache: cache, + claimsValidator: jwtgo.NewValidator(jwtgo.WithExpirationRequired()), }, nil } @@ -73,7 +75,7 @@ func stringToBytesUnsafe(s string) []byte { } // Parse calls Parse method of embedded original Parser but stores result into cache. -func (cp *CachingParser) Parse(ctx context.Context, token string) (*Claims, error) { +func (cp *CachingParser) Parse(ctx context.Context, token string) (Claims, error) { key := getTokenHash(stringToBytesUnsafe(token)) cachedClaims, foundInCache, validationErr := cp.getFromCacheAndValidateIfNeeded(key) if foundInCache { @@ -90,13 +92,13 @@ func (cp *CachingParser) Parse(ctx context.Context, token string) (*Claims, erro return claims, nil } -func (cp *CachingParser) getFromCacheAndValidateIfNeeded(key [sha256.Size]byte) (claims *Claims, found bool, err error) { +func (cp *CachingParser) getFromCacheAndValidateIfNeeded(key [sha256.Size]byte) (claims Claims, found bool, err error) { cachedClaims, found := cp.ClaimsCache.Get(key) if !found { return nil, false, nil } if !cp.Parser.skipClaimsValidation { - if err = cp.Parser.claimsValidator.Validate(cachedClaims); err != nil { + if err = cp.claimsValidator.Validate(cachedClaims); err != nil { return nil, true, fmt.Errorf("%w: %w", jwtgo.ErrTokenInvalidClaims, err) } if err = cp.Parser.customValidator(cachedClaims); err != nil { diff --git a/jwt/caching_parser_test.go b/jwt/caching_parser_test.go index e7648e6..a7d0296 100644 --- a/jwt/caching_parser_test.go +++ b/jwt/caching_parser_test.go @@ -27,7 +27,7 @@ func getTokenHash(token []byte) [sha256.Size]byte { } func TestGetTokenHash(t *testing.T) { - claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} + claims := &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} tokenString := []byte(idptest.MustMakeTokenStringSignedWithTestKey(claims)) th := getTokenHash(tokenString) @@ -35,7 +35,7 @@ func TestGetTokenHash(t *testing.T) { th2 := getTokenHash(tokenString) require.Equal(t, th, th2, "two hashes of the same token must be equal") - claims2 := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: "other" + testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(12 * time.Minute))}} + claims2 := &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: "other" + testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(12 * time.Minute))}} tokenString2 := []byte(idptest.MustMakeTokenStringSignedWithTestKey(claims2)) th3 := getTokenHash(tokenString2) require.NotEqual(t, th, th3, "two hashes of different tokens must be different") @@ -48,17 +48,17 @@ func TestCachingParser_Parse(t *testing.T) { issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) defer issuerConfigServer.Close() - claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} + claims := &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} tokenString := idptest.MustMakeTokenStringSignedWithTestKey(claims) parser, err := jwt.NewCachingParser(jwks.NewCachingClient()) require.NoError(t, err) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) - var parsedClaims *jwt.Claims + var parsedClaims jwt.Claims parsedClaims, err = parser.Parse(context.Background(), tokenString) require.NoError(t, err, "caching parser must not return error from Parse method") - require.Equal(t, claims.Scope, parsedClaims.Scope, "unexpected claims value produced by caching parser") + require.Equal(t, claims.Scope, parsedClaims.GetScope(), "unexpected claims value produced by caching parser") require.Equal(t, 1, parser.ClaimsCache.Len(), "one claims object must be cached after successful parse operation") @@ -66,7 +66,7 @@ func TestCachingParser_Parse(t *testing.T) { tokenKey := getTokenHash([]byte(tokenString)) cachedClaims, found := parser.ClaimsCache.Get(tokenKey) require.True(t, found, "cached claims object must be found by token hash") - require.Equal(t, claims.Scope, cachedClaims.Scope, "unexpected claims value fetched from parser cache") + require.Equal(t, claims.Scope, cachedClaims.GetScope(), "unexpected claims value fetched from parser cache") parser.InvalidateClaimsCache() require.Equal(t, 0, parser.ClaimsCache.Len(), @@ -82,17 +82,17 @@ func TestCachingParser_CheckExpiration(t *testing.T) { issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) defer issuerConfigServer.Close() - claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(jwtTTL))}} + claims := &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(jwtTTL))}} tokenString := idptest.MustMakeTokenStringSignedWithTestKey(claims) parser, err := jwt.NewCachingParser(jwks.NewCachingClient()) require.NoError(t, err) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) - var parsedClaims *jwt.Claims + var parsedClaims jwt.Claims parsedClaims, err = parser.Parse(context.Background(), tokenString) require.NoError(t, err, "caching parser must not return error from Parse method") - require.Equal(t, claims.Scope, parsedClaims.Scope, "unexpected claims value produced by caching parser") + require.Equal(t, claims.Scope, parsedClaims.GetScope(), "unexpected claims value produced by caching parser") require.Equal(t, 1, parser.ClaimsCache.Len(), "one claims object must be cached after successful parse operation") diff --git a/jwt/claims.go b/jwt/claims.go new file mode 100644 index 0000000..8f836f9 --- /dev/null +++ b/jwt/claims.go @@ -0,0 +1,124 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwt + +import jwtgo "github.com/golang-jwt/jwt/v5" + +// Scope is a slice of access policies. +type Scope []AccessPolicy + +// Claims is an interface that extends jwt.Claims from the "github.com/golang-jwt/jwt/v5" +// with additional methods for working with access policies. +type Claims interface { + jwtgo.Claims + + // GetScope returns the scope of the claims as a slice of access policies. + GetScope() Scope + + // Clone returns a deep copy of the claims. + Clone() Claims + + // ApplyScopeFilter filters (in-place) the scope of the claims by the specified filter. + ApplyScopeFilter(filter ScopeFilter) +} + +// DefaultClaims is a struct that extends jwt.RegisteredClaims with a custom scope field. +// It may be embedded into custom claims structs if additional fields are required. +type DefaultClaims struct { + jwtgo.RegisteredClaims + Scope Scope `json:"scope,omitempty"` +} + +// GetScope returns the scope of the DefaultClaims as a slice of access policies. +func (c *DefaultClaims) GetScope() Scope { + return c.Scope +} + +// Clone returns a deep copy of the DefaultClaims. +func (c *DefaultClaims) Clone() Claims { + newClaims := &DefaultClaims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: c.Issuer, + Subject: c.Subject, + ID: c.ID, + }, + } + if len(c.Scope) != 0 { + newClaims.Scope = make([]AccessPolicy, len(c.Scope)) + copy(newClaims.Scope, c.Scope) + } + if len(c.Audience) != 0 { + newClaims.Audience = make(jwtgo.ClaimStrings, len(c.Audience)) + copy(newClaims.Audience, c.Audience) + } + if c.ExpiresAt != nil { + newClaims.ExpiresAt = jwtgo.NewNumericDate(c.ExpiresAt.Time) + } + if c.NotBefore != nil { + newClaims.NotBefore = jwtgo.NewNumericDate(c.NotBefore.Time) + } + if c.IssuedAt != nil { + newClaims.IssuedAt = jwtgo.NewNumericDate(c.IssuedAt.Time) + } + return newClaims +} + +// ScopeFilter is a slice of access policy filters. +type ScopeFilter []ScopeFilterAccessPolicy + +// ScopeFilterAccessPolicy is a struct that represents a single access policy filter. +type ScopeFilterAccessPolicy struct { + ResourceNamespace string +} + +// ApplyScopeFilter filters (in-place) the scope of the DefaultClaims by the specified filter. +func (c *DefaultClaims) ApplyScopeFilter(filter ScopeFilter) { + if len(filter) == 0 { + return + } + n := 0 + for j := range c.Scope { + matched := false + for k := range filter { + if c.Scope[j].ResourceNamespace == filter[k].ResourceNamespace { + matched = true + break + } + } + if matched { + c.Scope[n] = c.Scope[j] + n++ + } + } + c.Scope = c.Scope[:n] +} + +// AccessPolicy represents a single access policy which specifies access rights to a tenant or resource +// in the scope of a resource server. +type AccessPolicy struct { + // TenantID is a unique identifier of tenant for which access is granted (if resource is not specified) + // or which the resource is owned by (if resource is specified). + TenantID string `json:"tid,omitempty"` + + // TenantUUID is a UUID of tenant for which access is granted (if the resource is not specified) + // or which the resource is owned by (if the resource is specified). + TenantUUID string `json:"tuid,omitempty"` + + // ResourceServerID is a unique resource server instance or cluster ID. + ResourceServerID string `json:"rs,omitempty"` + + // ResourceNamespace is a namespace to which resource belongs within resource server. + // E.g.: account-server, storage-manager, task-manager, alert-manager, etc. + ResourceNamespace string `json:"rn,omitempty"` + + // ResourcePath is a unique identifier of or path to a single resource or resource collection + // in the scope of the resource server and namespace. + ResourcePath string `json:"rp,omitempty"` + + // Role determines what actions are allowed to be performed on the specified tenant or resource. + Role string `json:"role,omitempty"` +} diff --git a/jwt/claims_test.go b/jwt/claims_test.go new file mode 100644 index 0000000..000c674 --- /dev/null +++ b/jwt/claims_test.go @@ -0,0 +1,180 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwt_test + +import ( + "testing" + "time" + + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/jwt" +) + +func TestDefaultClaims_ApplyScopeFilter(t *testing.T) { + tests := []struct { + name string + claims jwt.DefaultClaims + filter jwt.ScopeFilter + expectedScope jwt.Scope + }{ + { + name: "no filter", + claims: jwt.DefaultClaims{ + Scope: jwt.Scope{ + {ResourceNamespace: "namespace1"}, + {ResourceNamespace: "namespace2"}, + }, + }, + filter: nil, + expectedScope: jwt.Scope{ + {ResourceNamespace: "namespace1"}, + {ResourceNamespace: "namespace2"}, + }, + }, + { + name: "filter matches all", + claims: jwt.DefaultClaims{ + Scope: jwt.Scope{ + {ResourceNamespace: "namespace1"}, + {ResourceNamespace: "namespace2"}, + }, + }, + filter: jwt.ScopeFilter{ + {ResourceNamespace: "namespace1"}, + {ResourceNamespace: "namespace2"}, + }, + expectedScope: jwt.Scope{ + {ResourceNamespace: "namespace1"}, + {ResourceNamespace: "namespace2"}, + }, + }, + { + name: "filter matches some", + claims: jwt.DefaultClaims{ + Scope: jwt.Scope{ + {ResourceNamespace: "namespace1"}, + {ResourceNamespace: "namespace2"}, + }, + }, + filter: jwt.ScopeFilter{ + {ResourceNamespace: "namespace1"}, + }, + expectedScope: jwt.Scope{ + {ResourceNamespace: "namespace1"}, + }, + }, + { + name: "filter matches none", + claims: jwt.DefaultClaims{ + Scope: jwt.Scope{ + {ResourceNamespace: "namespace1"}, + {ResourceNamespace: "namespace2"}, + }, + }, + filter: jwt.ScopeFilter{ + {ResourceNamespace: "namespace3"}, + }, + expectedScope: jwt.Scope{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.claims.ApplyScopeFilter(tt.filter) + require.Equal(t, tt.expectedScope, tt.claims.Scope) + }) + } +} + +func TestDefaultClaims_Clone(t *testing.T) { + tests := []struct { + name string + claims jwt.DefaultClaims + }{ + { + name: "empty claims", + claims: jwt.DefaultClaims{}, + }, + { + name: "claims with jwt.Scope", + claims: jwt.DefaultClaims{ + Scope: jwt.Scope{ + {ResourceNamespace: "namespace1"}, + {ResourceNamespace: "namespace2"}, + }, + }, + }, + { + name: "claims with registered fields", + claims: jwt.DefaultClaims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: "issuer", + Subject: "subject", + Audience: []string{"audience1", "audience2"}, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + NotBefore: jwtgo.NewNumericDate(time.Now().Add(-time.Hour)), + IssuedAt: jwtgo.NewNumericDate(time.Now()), + ID: "id", + }, + }, + }, + { + name: "claims with all fields", + claims: jwt.DefaultClaims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: "issuer", + Subject: "subject", + Audience: []string{"audience1", "audience2"}, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + NotBefore: jwtgo.NewNumericDate(time.Now().Add(-time.Hour)), + IssuedAt: jwtgo.NewNumericDate(time.Now()), + ID: "id", + }, + Scope: jwt.Scope{ + {ResourceNamespace: "namespace1"}, + {ResourceNamespace: "namespace2"}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clone := tt.claims.Clone().(*jwt.DefaultClaims) + require.Equal(t, tt.claims, *clone) + require.NotSame(t, &tt.claims, clone) + + // Modify original claims and ensure clone is not affected + if len(tt.claims.Scope) > 0 { + tt.claims.Scope[0].ResourceNamespace = "modified" + require.NotEqual(t, tt.claims.Scope[0].ResourceNamespace, clone.Scope[0].ResourceNamespace) + } + if len(tt.claims.Audience) > 0 { + tt.claims.Audience[0] = "modified" + require.NotEqual(t, tt.claims.Audience[0], clone.Audience[0]) + } + if tt.claims.ExpiresAt != nil { + tt.claims.ExpiresAt.Time = time.Now().Add(2 * time.Hour) + require.NotEqual(t, tt.claims.ExpiresAt, clone.ExpiresAt) + } + }) + } +} + +type CustomClaims struct { + jwt.DefaultClaims + CustomField string `json:"custom_field"` +} + +func (c *CustomClaims) Clone() jwt.Claims { + return &CustomClaims{ + DefaultClaims: *c.DefaultClaims.Clone().(*jwt.DefaultClaims), + CustomField: c.CustomField, + } +} diff --git a/jwt/errors.go b/jwt/errors.go index fb3d89d..bcce08d 100644 --- a/jwt/errors.go +++ b/jwt/errors.go @@ -8,6 +8,8 @@ package jwt import ( "fmt" + + jwtgo "github.com/golang-jwt/jwt/v5" ) // SignAlgUnknownError represents an error when JWT signing algorithm is unknown. @@ -21,16 +23,17 @@ func (e *SignAlgUnknownError) Error() string { // IssuerUntrustedError represents an error when JWT issuer is untrusted. type IssuerUntrustedError struct { - Claims *Claims + Claims Claims + Issuer string } func (e *IssuerUntrustedError) Error() string { - return fmt.Sprintf("JWT issuer %q untrusted", e.Claims.Issuer) + return fmt.Sprintf("JWT issuer %q untrusted", e.Issuer) } // IssuerMissingError represents an error when JWT issuer is missing. type IssuerMissingError struct { - Claims *Claims + Claims Claims } func (e *IssuerMissingError) Error() string { @@ -39,7 +42,7 @@ func (e *IssuerMissingError) Error() string { // AudienceMissingError represents an error when JWT audience is missing, but it's required. type AudienceMissingError struct { - Claims *Claims + Claims Claims } func (e *AudienceMissingError) Error() string { @@ -48,9 +51,10 @@ func (e *AudienceMissingError) Error() string { // AudienceNotExpectedError represents an error when JWT contains not expected audience. type AudienceNotExpectedError struct { - Claims *Claims + Claims Claims + Audience jwtgo.ClaimStrings } func (e *AudienceNotExpectedError) Error() string { - return fmt.Sprintf("JWT audience %q not expected", e.Claims.Audience) + return fmt.Sprintf("JWT audience %v not expected", e.Audience) } diff --git a/jwt/parser.go b/jwt/parser.go index 6177980..1ea430b 100644 --- a/jwt/parser.go +++ b/jwt/parser.go @@ -37,6 +37,7 @@ type ParserOpts struct { ExpectedAudience []string TrustedIssuerNotFoundFallback TrustedIssNotFoundFallback LoggerProvider func(ctx context.Context) log.FieldLogger + ClaimsTemplate Claims } type audienceMatcher func(aud string) bool @@ -48,8 +49,8 @@ type TrustedIssNotFoundFallback func(ctx context.Context, p *Parser, iss string) // Parser is an object for parsing, validation and verification JWT. type Parser struct { parser *jwtgo.Parser - claimsValidator *jwtgo.Validator - customValidator func(claims *Claims) error + claimsTemplate Claims + customValidator func(claims Claims) error skipClaimsValidation bool keysProvider KeysProvider @@ -74,15 +75,19 @@ func NewParserWithOpts(keysProvider KeysProvider, opts ParserOpts) *Parser { if opts.SkipClaimsValidation { parserOpts = append(parserOpts, jwtgo.WithoutClaimsValidation()) } + var claimsTemplate Claims = &DefaultClaims{} + if opts.ClaimsTemplate != nil { + claimsTemplate = opts.ClaimsTemplate + } return &Parser{ parser: jwtgo.NewParser(parserOpts...), - claimsValidator: jwtgo.NewValidator(jwtgo.WithExpirationRequired()), customValidator: makeCustomAudienceValidator(opts.RequireAudience, audienceMatchers), skipClaimsValidation: opts.SkipClaimsValidation, keysProvider: keysProvider, trustedIssuerStore: idputil.NewTrustedIssuerStore(), trustedIssuerNotFoundFallback: opts.TrustedIssuerNotFoundFallback, loggerProvider: opts.LoggerProvider, + claimsTemplate: claimsTemplate, } } @@ -102,10 +107,10 @@ func (p *Parser) GetURLForIssuer(issuer string) (string, bool) { } // Parse parses, validates and verifies passed token (it's string representation). Parsed claims is returned. -func (p *Parser) Parse(ctx context.Context, token string) (*Claims, error) { +func (p *Parser) Parse(ctx context.Context, token string) (Claims, error) { keyFunc := p.getKeyFunc(ctx) - claims := validatableClaims{customValidator: p.customValidator} - if _, err := p.parser.ParseWithClaims(token, &claims, keyFunc); err != nil { + claims := p.claimsTemplate.Clone() + if _, err := p.parser.ParseWithClaims(token, claims, keyFunc); err != nil { if !errors.Is(err, jwtgo.ErrTokenSignatureInvalid) { return nil, err } @@ -116,7 +121,12 @@ func (p *Parser) Parse(ctx context.Context, token string) (*Claims, error) { return nil, err } - issuerURL, issuerURLFound := p.getURLForIssuerWithCallback(ctx, claims.Issuer) + issuer, issuerErr := claims.GetIssuer() + if issuerErr != nil { + return nil, err // original error is more important + } + + issuerURL, issuerURLFound := p.getURLForIssuerWithCallback(ctx, issuer) if !issuerURLFound { return nil, err } @@ -127,12 +137,18 @@ func (p *Parser) Parse(ctx context.Context, token string) (*Claims, error) { return nil, err } - if _, err = p.parser.ParseWithClaims(token, &claims, keyFunc); err != nil { + if _, err = p.parser.ParseWithClaims(token, claims, keyFunc); err != nil { return nil, err } } - return &claims.Claims, nil + if !p.skipClaimsValidation { + if err := p.customValidator(claims); err != nil { + return nil, fmt.Errorf("%w: %w", jwtgo.ErrTokenInvalidClaims, err) + } + } + + return claims, nil } func (p *Parser) getKeyFunc(ctx context.Context) func(token *jwtgo.Token) (interface{}, error) { @@ -147,13 +163,20 @@ func (p *Parser) getKeyFunc(ctx context.Context) func(token *jwtgo.Token) (inter if kid, found := token.Header["kid"]; found { kidStr = kid.(string) } - claims := token.Claims.(*validatableClaims) - if claims.Issuer == "" { - return nil, &IssuerMissingError{&claims.Claims} + claims, ok := token.Claims.(Claims) + if !ok { + return nil, fmt.Errorf("claims type %T does not implement Claims interface", token.Claims) } - issuerURL, issuerURLFound := p.getURLForIssuerWithCallback(ctx, claims.Issuer) + issuer, issuerErr := claims.GetIssuer() + if issuerErr != nil { + return nil, issuerErr + } + if issuer == "" { + return nil, &IssuerMissingError{claims} + } + issuerURL, issuerURLFound := p.getURLForIssuerWithCallback(ctx, issuer) if !issuerURLFound { - return nil, &IssuerUntrustedError{&claims.Claims} + return nil, &IssuerUntrustedError{claims, issuer} } return p.keysProvider.GetRSAPublicKey(ctx, issuerURL, kidStr) @@ -174,60 +197,13 @@ func (p *Parser) getURLForIssuerWithCallback(ctx context.Context, issuer string) return p.trustedIssuerNotFoundFallback(ctx, p, issuer) } -// Claims represents an extended version of JWT claims. -type Claims struct { - jwtgo.RegisteredClaims - Scope []AccessPolicy `json:"scope,omitempty"` - Version int `json:"ver,omitempty"` - UserID string `json:"uid,omitempty"` - OriginID string `json:"origin,omitempty"` - ClientID string `json:"client_id,omitempty"` - TOTPTime int64 `json:"totp_time,omitempty"` - SubType string `json:"sub_type,omitempty"` - OwnerTenantUUID string `json:"owner_tuid,omitempty"` -} - -// AccessPolicy represents a single access policy which specifies access rights to a tenant or resource -// in the scope of a resource server. -type AccessPolicy struct { - // TenantID is a unique identifier of tenant for which access is granted (if resource is not specified) - // or which the resource is owned by (if resource is specified). - TenantID string `json:"tid,omitempty"` - - // TenantUUID is a UUID of tenant for which access is granted (if the resource is not specified) - // or which the resource is owned by (if the resource is specified). - TenantUUID string `json:"tuid,omitempty"` - - // ResourceServerID is a unique resource server instance or cluster ID. - ResourceServerID string `json:"rs,omitempty"` - - // ResourceNamespace is a namespace to which resource belongs within resource server. - // E.g.: account-server, storage-manager, task-manager, alert-manager, etc. - ResourceNamespace string `json:"rn,omitempty"` - - // ResourcePath is a unique identifier of or path to a single resource or resource collection - // in the scope of the resource server and namespace. - ResourcePath string `json:"rp,omitempty"` - - // Role determines what actions are allowed to be performed on the specified tenant or resource. - Role string `json:"role,omitempty"` -} - -type validatableClaims struct { - Claims - customValidator func(c *Claims) error -} - -func (v *validatableClaims) Validate() error { - if v.customValidator != nil { - return v.customValidator(&v.Claims) - } - return nil -} - -func makeCustomAudienceValidator(requireAudience bool, audienceMatchers []audienceMatcher) func(c *Claims) error { - return func(c *Claims) error { - if len(c.Audience) == 0 { +func makeCustomAudienceValidator(requireAudience bool, audienceMatchers []audienceMatcher) func(c Claims) error { + return func(c Claims) error { + audience, err := c.GetAudience() + if err != nil { + return err + } + if len(audience) == 0 { if requireAudience { return fmt.Errorf("%w: %w", jwtgo.ErrTokenRequiredClaimMissing, &AudienceMissingError{c}) } @@ -238,12 +214,12 @@ func makeCustomAudienceValidator(requireAudience bool, audienceMatchers []audien return nil } for i := range audienceMatchers { - for j := range c.Audience { - if audienceMatchers[i](c.Audience[j]) { + for j := range audience { + if audienceMatchers[i](audience[j]) { return nil } } } - return fmt.Errorf("%w: %w", jwtgo.ErrTokenInvalidAudience, &AudienceNotExpectedError{c}) + return fmt.Errorf("%w: %w", jwtgo.ErrTokenInvalidAudience, &AudienceNotExpectedError{c, audience}) } } diff --git a/jwt/parser_test.go b/jwt/parser_test.go index fdf2a59..cb40e6f 100644 --- a/jwt/parser_test.go +++ b/jwt/parser_test.go @@ -13,6 +13,7 @@ import ( "time" jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/acronis/go-authkit/idptest" @@ -30,26 +31,41 @@ func TestJWTParser_Parse(t *testing.T) { defer issuerConfigServer.Close() t.Run("ok", func(t *testing.T) { - claims := &jwt.Claims{ + claims := &jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), }, - Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, - TOTPTime: time.Now().Unix(), - SubType: "task_manager", + Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, } parser := jwt.NewParser(jwks.NewCachingClient()) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.NoError(t, err) - require.Equal(t, claims.Scope, parsedClaims.Scope) - require.Equal(t, claims.TOTPTime, parsedClaims.TOTPTime) - require.Equal(t, claims.SubType, parsedClaims.SubType) + require.Equal(t, claims, parsedClaims) + }) + + t.Run("ok, custom claims", func(t *testing.T) { + claims := &CustomClaims{ + DefaultClaims: jwt.DefaultClaims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, + }, + CustomField: uuid.NewString(), + } + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), jwt.ParserOpts{ClaimsTemplate: &CustomClaims{}}) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + tokenString := idptest.MustMakeTokenStringSignedWithTestKey(claims) + parsedClaims, err := parser.Parse(context.Background(), tokenString) + require.NoError(t, err) + require.Equal(t, claims, parsedClaims) }) t.Run("ok for empty kid", func(t *testing.T) { - claims := &jwt.Claims{ + claims := &jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), @@ -63,7 +79,7 @@ func TestJWTParser_Parse(t *testing.T) { }) t.Run("ok for trusted issuer url (glob pattern)", func(t *testing.T) { - claims := &jwt.Claims{ + claims := &jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: issuerConfigServer.URL, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), @@ -86,7 +102,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("ok for expected audience (glob pattern)", func(t *testing.T) { for _, aud := range []string{"region1.cloud.com", "region2.cloud.com"} { - claims := &jwt.Claims{ + claims := &jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Audience: []string{aud}, Issuer: issuerConfigServer.URL, @@ -112,7 +128,7 @@ func TestJWTParser_Parse(t *testing.T) { }) t.Run("unsigned jwt", func(t *testing.T) { - claims := &jwt.Claims{ + claims := jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), @@ -129,7 +145,7 @@ func TestJWTParser_Parse(t *testing.T) { }) t.Run("jwt issuer missing", func(t *testing.T) { - claims := &jwt.Claims{ + claims := &jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{Audience: []string{"https://cloud.acronis.com"}}, } parser := jwt.NewParser(jwks.NewCachingClient()) @@ -142,7 +158,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("jwt has untrusted issuer", func(t *testing.T) { const issuer = "untrusted-issuer" - claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}} + claims := &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}} parser := jwt.NewParser(jwks.NewCachingClient()) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) @@ -153,7 +169,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("jwt has untrusted issuer url", func(t *testing.T) { const issuer = "https://3rd-party-idp.com" - claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}} + claims := &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}} parser := jwt.NewParser(jwks.NewCachingClient()) require.NoError(t, parser.AddTrustedIssuerURL("https://*.acronis.com")) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) @@ -165,7 +181,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("jwt has untrusted issuer url, callback adds it to trusted", func(t *testing.T) { var callbackCallCount int - claims := &jwt.Claims{ + claims := &jwt.DefaultClaims{ RegisteredClaims: jwtgo.RegisteredClaims{ Audience: []string{issuerConfigServer.URL}, Issuer: issuerConfigServer.URL, @@ -195,40 +211,40 @@ func TestJWTParser_Parse(t *testing.T) { }) t.Run("jwt exp is missing", func(t *testing.T) { - claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss}} + claims := jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss}} parser := jwt.NewParser(jwks.NewCachingClient()) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) - _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(&claims)) require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) require.ErrorIs(t, err, jwtgo.ErrTokenRequiredClaimMissing) }) t.Run("jwt expired", func(t *testing.T) { expiresAt := time.Now().Add(-time.Second) - claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(expiresAt)}} + claims := jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(expiresAt)}} parser := jwt.NewParser(jwks.NewCachingClient()) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) - _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(&claims)) require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) require.ErrorIs(t, err, jwtgo.ErrTokenExpired) }) t.Run("jwt not valid yet", func(t *testing.T) { notBefore := time.Now().Add(time.Minute) - claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{ + claims := jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), NotBefore: jwtgo.NewNumericDate(notBefore), }} parser := jwt.NewParser(jwks.NewCachingClient()) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) - _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(&claims)) require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) require.ErrorIs(t, err, jwtgo.ErrTokenNotValidYet) }) t.Run("required jwt audience is missing", func(t *testing.T) { - claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{ + claims := &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), }} @@ -246,7 +262,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("jwt audience is not expected", func(t *testing.T) { const audience = "not-expected-audience" - claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{ + claims := &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{ Audience: []string{audience}, Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), @@ -273,8 +289,8 @@ func TestJWTParser_Parse(t *testing.T) { const cacheUpdateMinInterval = time.Second - claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} - tokenString, err := idptest.MakeTokenString(claims, "737c5114f09b5ed05276bd4b520245982f7fb29f", idptest.GetTestRSAPrivateKey()) + claims := jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} + tokenString, err := idptest.MakeTokenString(&claims, "737c5114f09b5ed05276bd4b520245982f7fb29f", idptest.GetTestRSAPrivateKey()) require.NoError(t, err) jwksClient := jwks.NewCachingClientWithOpts(jwks.CachingClientOpts{CacheUpdateMinInterval: cacheUpdateMinInterval}) parser := jwt.NewParser(jwksClient) diff --git a/middleware.go b/middleware.go index 6a8a1c2..1b5b69f 100644 --- a/middleware.go +++ b/middleware.go @@ -49,7 +49,7 @@ const ( // JWTParser is an interface for parsing string representation of JWT. type JWTParser interface { - Parse(ctx context.Context, token string) (*jwt.Claims, error) + Parse(ctx context.Context, token string) (jwt.Claims, error) } // CachingJWTParser does the same as JWTParser but stores parsed JWT claims in cache. @@ -67,13 +67,13 @@ type jwtAuthHandler struct { next http.Handler errorDomain string jwtParser JWTParser - verifyAccess func(r *http.Request, claims *jwt.Claims) bool + verifyAccess func(r *http.Request, claims jwt.Claims) bool tokenIntrospector TokenIntrospector loggerProvider func(ctx context.Context) log.FieldLogger } type jwtAuthMiddlewareOpts struct { - verifyAccess func(r *http.Request, claims *jwt.Claims) bool + verifyAccess func(r *http.Request, claims jwt.Claims) bool tokenIntrospector TokenIntrospector loggerProvider func(ctx context.Context) log.FieldLogger } @@ -82,7 +82,7 @@ type jwtAuthMiddlewareOpts struct { type JWTAuthMiddlewareOption func(options *jwtAuthMiddlewareOpts) // WithJWTAuthMiddlewareVerifyAccess is an option to set a function that verifies access for JWTAuthMiddleware. -func WithJWTAuthMiddlewareVerifyAccess(verifyAccess func(r *http.Request, claims *jwt.Claims) bool) JWTAuthMiddlewareOption { +func WithJWTAuthMiddlewareVerifyAccess(verifyAccess func(r *http.Request, claims jwt.Claims) bool) JWTAuthMiddlewareOption { return func(options *jwtAuthMiddlewareOpts) { options.verifyAccess = verifyAccess } @@ -137,7 +137,7 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { return } - var jwtClaims *jwt.Claims + var jwtClaims jwt.Claims if h.tokenIntrospector != nil { if introspectionResult, err := h.tokenIntrospector.IntrospectToken(reqCtx, bearerToken); err != nil { switch { @@ -159,13 +159,13 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { return } } else { - if !introspectionResult.Active { + if !introspectionResult.IsActive() { h.logger(reqCtx).Warn("token was successfully introspected, but it is not active") apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx)) return } - jwtClaims = &introspectionResult.Claims + jwtClaims = introspectionResult.GetClaims() h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { logFunc("token was successfully introspected") }) @@ -210,17 +210,17 @@ func GetBearerTokenFromRequest(r *http.Request) string { } // NewContextWithJWTClaims creates a new context with JWT claims. -func NewContextWithJWTClaims(ctx context.Context, jwtClaims *jwt.Claims) context.Context { +func NewContextWithJWTClaims(ctx context.Context, jwtClaims jwt.Claims) context.Context { return context.WithValue(ctx, ctxKeyJWTClaims, jwtClaims) } // GetJWTClaimsFromContext extracts JWT claims from the context. -func GetJWTClaimsFromContext(ctx context.Context) *jwt.Claims { +func GetJWTClaimsFromContext(ctx context.Context) jwt.Claims { value := ctx.Value(ctxKeyJWTClaims) if value == nil { return nil } - return value.(*jwt.Claims) + return value.(jwt.Claims) } // NewContextWithBearerToken creates a new context with token. diff --git a/middleware_test.go b/middleware_test.go index 0998e58..441d939 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -23,7 +23,7 @@ import ( type mockJWTAuthMiddlewareNextHandler struct { called int - jwtClaims *jwt.Claims + jwtClaims jwt.Claims } func (h *mockJWTAuthMiddlewareNextHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { @@ -33,12 +33,12 @@ func (h *mockJWTAuthMiddlewareNextHandler) ServeHTTP(_ http.ResponseWriter, r *h type mockJWTParser struct { parseCalled int - claimsToReturn *jwt.Claims + claimsToReturn jwt.Claims errToReturn error passedToken string } -func (p *mockJWTParser) Parse(_ context.Context, token string) (*jwt.Claims, error) { +func (p *mockJWTParser) Parse(_ context.Context, token string) (jwt.Claims, error) { p.parseCalled++ p.passedToken = token return p.claimsToReturn, p.errToReturn @@ -96,7 +96,7 @@ func TestJWTAuthMiddleware(t *testing.T) { t.Run("ok", func(t *testing.T) { const issuer = "my-idp.com" - parser := &mockJWTParser{claimsToReturn: &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) req.Header.Set(HeaderAuthorization, "Bearer a.b.c") @@ -108,12 +108,14 @@ func TestJWTAuthMiddleware(t *testing.T) { require.Equal(t, 1, parser.parseCalled) require.Equal(t, 1, next.called) require.NotNil(t, next.jwtClaims) - require.Equal(t, issuer, next.jwtClaims.Issuer) + nextIssuer, err := next.jwtClaims.GetIssuer() + require.NoError(t, err) + require.Equal(t, issuer, nextIssuer) }) t.Run("introspection failed", func(t *testing.T) { const issuer = "my-idp.com" - parser := &mockJWTParser{claimsToReturn: &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} introspector := &mockTokenIntrospector{errToReturn: errors.New("introspection failed")} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) @@ -130,7 +132,7 @@ func TestJWTAuthMiddleware(t *testing.T) { t.Run("introspection is not needed", func(t *testing.T) { const issuer = "my-idp.com" - parser := &mockJWTParser{claimsToReturn: &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} introspector := &mockTokenIntrospector{errToReturn: idptoken.ErrTokenIntrospectionNotNeeded} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) @@ -143,13 +145,14 @@ func TestJWTAuthMiddleware(t *testing.T) { require.Equal(t, 1, introspector.introspectCalled) require.Equal(t, 1, parser.parseCalled) require.Equal(t, 1, next.called) - require.NotNil(t, next.jwtClaims) - require.Equal(t, issuer, next.jwtClaims.Issuer) + nextIssuer, err := next.jwtClaims.GetIssuer() + require.NoError(t, err) + require.Equal(t, issuer, nextIssuer) }) t.Run("ok, token is not introspectable", func(t *testing.T) { const issuer = "my-idp.com" - parser := &mockJWTParser{claimsToReturn: &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} introspector := &mockTokenIntrospector{errToReturn: idptoken.ErrTokenNotIntrospectable} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) @@ -163,13 +166,15 @@ func TestJWTAuthMiddleware(t *testing.T) { require.Equal(t, 1, parser.parseCalled) require.Equal(t, 1, next.called) require.NotNil(t, next.jwtClaims) - require.Equal(t, issuer, next.jwtClaims.Issuer) + nextIssuer, err := next.jwtClaims.GetIssuer() + require.NoError(t, err) + require.Equal(t, issuer, nextIssuer) }) t.Run("authentication failed, token is introspected but inactive", func(t *testing.T) { const issuer = "my-idp.com" parser := &mockJWTParser{} - introspector := &mockTokenIntrospector{resultToReturn: idptoken.IntrospectionResult{Active: false}} + introspector := &mockTokenIntrospector{resultToReturn: &idptoken.DefaultIntrospectionResult{Active: false}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) req.Header.Set(HeaderAuthorization, "Bearer a.b.c") @@ -186,7 +191,8 @@ func TestJWTAuthMiddleware(t *testing.T) { t.Run("ok, token is introspected and active", func(t *testing.T) { const issuer = "my-idp.com" parser := &mockJWTParser{} - introspector := &mockTokenIntrospector{resultToReturn: idptoken.IntrospectionResult{Active: true, Claims: jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}}} + introspector := &mockTokenIntrospector{resultToReturn: &idptoken.DefaultIntrospectionResult{ + Active: true, DefaultClaims: jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) req.Header.Set(HeaderAuthorization, "Bearer a.b.c") @@ -199,7 +205,9 @@ func TestJWTAuthMiddleware(t *testing.T) { require.Equal(t, 0, parser.parseCalled) require.Equal(t, 1, next.called) require.NotNil(t, next.jwtClaims) - require.Equal(t, issuer, next.jwtClaims.Issuer) + nextIssuer, err := next.jwtClaims.GetIssuer() + require.NoError(t, err) + require.Equal(t, issuer, nextIssuer) }) } @@ -207,7 +215,7 @@ func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) { const errDomain = "TestDomain" t.Run("authorization failed", func(t *testing.T) { - parser := &mockJWTParser{claimsToReturn: &jwt.Claims{}} + parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) req.Header.Set(HeaderAuthorization, "Bearer a.b.c") @@ -224,7 +232,7 @@ func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) { t.Run("ok", func(t *testing.T) { scope := []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "admin"}} - parser := &mockJWTParser{claimsToReturn: &jwt.Claims{Scope: scope}} + parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{Scope: scope}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) req.Header.Set(HeaderAuthorization, "Bearer a.b.c") @@ -237,6 +245,6 @@ func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) { require.Equal(t, 1, parser.parseCalled) require.Equal(t, 1, next.called) require.NotNil(t, next.jwtClaims) - require.EqualValues(t, scope, next.jwtClaims.Scope) + require.EqualValues(t, scope, next.jwtClaims.GetScope()) }) }