diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9ed0094 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.idea/ +.vscode/ +vendor/ \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..acf3fe4 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,78 @@ +linters-settings: + gocyclo: + min-complexity: 25 + goconst: + min-len: 2 + min-occurrences: 2 + misspell: + locale: US + lll: + line-length: 140 + goimports: + local-prefixes: github.com/acronis/go-authkit/ + gocritic: + enabled-tags: + - diagnostic + - performance + - style + - experimental + disabled-checks: + - whyNoLint + - paramTypeCombine + - sloppyReassign + settings: + hugeParam: + sizeThreshold: 256 + rangeValCopy: + sizeThreshold: 256 + funlen: + lines: 120 + statements: 60 + +linters: + disable-all: true + enable: + - bodyclose + - dogsled + - errcheck + - exportloopref + - funlen + - gochecknoinits + - goconst + - gocritic + - gocyclo + - gofmt + - goimports + - gosec + - gosimple + - govet + - ineffassign + - lll + - misspell + - nakedret + - staticcheck + - stylecheck + - typecheck + - unconvert + - unparam + - unused + - whitespace + +issues: + # Don't use default excluding to be sure all exported things (method, functions, consts and so on) have comments. + exclude-use-default: false + exclude-rules: + - path: _test\.go + linters: + - dogsled + - ineffassign + - funlen + - gocritic + - gocyclo + - gosec + - goconst + - govet + - lll + - staticcheck + - unused + - unparam diff --git a/.trufflehog3.yml b/.trufflehog3.yml new file mode 100644 index 0000000..b02d531 --- /dev/null +++ b/.trufflehog3.yml @@ -0,0 +1,9 @@ +exclude: + - message: Exclude values in test files and primitives for testing + paths: + - idptest/jwks_handler.go + - jwt/jwt_test.go + + - message: Skip false positive high-entropy sequences + id: high-entropy + pattern: cfgKeyIntrospectionGRPCTLSEnabled diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000..027c494 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @vasayxtx @MikeYast \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..807b16f --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright © 2024 Acronis International GmbH. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..3b16c7b --- /dev/null +++ b/README.md @@ -0,0 +1,137 @@ +# Simple library in Go with primitives for performing authentication and authorization + +The library includes the following packages: ++ `auth` (root directory) - provides authentication and authorization primitives for using on the server side. ++ `jwt` - provides parser for JSON Web Tokens (JWT). ++ `jwks` - provides a client for fetching and caching JSON Web Key Sets (JWKS). ++ `idptoken` - provides a client for fetching and caching Access Tokens from Identity Providers (IDP). ++ `idptest` - provides primitives for testing IDP clients. + +## Examples + +### Authenticating requests with JWT tokens + +The `JWTAuthMiddleware` function creates a middleware that authenticates requests with JWT tokens. + +It uses the `JWTParser` to parse and validate JWT. +`JWTParser` can verify JWT tokens signed with RSA (RS256, RS384, RS512) algorithms for now. +It performs /.well-known/openid-configuration request to get the JWKS URL ("jwks_uri" field) and fetches JWKS from there. +For other algorithms `jwt.SignAlgUnknownError` error will be returned. +The `JWTParser` can be created with the `NewJWTParser` function or with the `NewJWTParserWithCachingJWKS` function. +The last one is recommended for production use because it caches public keys (JWKS) that are used for verifying JWT tokens. + +See `Config` struct for more customization options. + +Example: + +```go +package main + +import ( + "net/http" + + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-authkit" +) + +func main() { + jwtConfig := auth.JWTConfig{ + TrustedIssuerURLs: []string{"https://my-idp.com"}, + //TrustedIssuers: map[string]string{"my-idp": "https://my-idp.com"}, // Use TrustedIssuers if you have a custom issuer name. + } + jwtParser, _ := auth.NewJWTParserWithCachingJWKS(&auth.Config{JWT: jwtConfig}, log.NewDisabledLogger()) + authN := auth.JWTAuthMiddleware("MyService", jwtParser) + + srvMux := http.NewServeMux() + srvMux.Handle("/", http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + _, _ = rw.Write([]byte("Hello, World!")) + })) + srvMux.Handle("/admin", authN(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + //jwtClaims := GetJWTClaimsFromContext(r.Context()) // GetJWTClaimsFromContext is a helper function to get JWT claims from context. + _, _ = rw.Write([]byte("Hello, admin!")) + }))) + + _ = http.ListenAndServe(":8080", srvMux) +} +``` + +```shell +$ curl -w "\nHTTP code: %{http_code}\n" localhost:8080 +Hello, World! +HTTP code: 200 + +$ curl -w "\nHTTP code: %{http_code}\n" localhost:8080/admin +{"error":{"domain":"MyService","code":"bearerTokenMissing","message":"Authorization bearer token is missing."}} +HTTP code: 401 +``` + +### Authorizing requests with JWT tokens + +```go +package main + +import ( + "net/http" + + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-authkit" +) + +func main() { + jwtConfig := auth.JWTConfig{TrustedIssuers: map[string]string{"my-idp": idpURL}} + jwtParser, _ := auth.NewJWTParserWithCachingJWKS(&auth.Config{JWT: jwtConfig}, log.NewDisabledLogger()) + authOnlyAdmin := auth.JWTAuthMiddlewareWithVerifyAccess("MyService", jwtParser, + auth.NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"})) + + srvMux := http.NewServeMux() + srvMux.Handle("/", http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + _, _ = rw.Write([]byte("Hello, World!")) + })) + srvMux.Handle("/admin", authOnlyAdmin(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + _, _ = rw.Write([]byte("Hello, admin!")) + }))) + + _ = http.ListenAndServe(":8080", srvMux) +} +``` + +Please see [example_test.go](./example_test.go) for a full version of the example. + +### Fetching and caching Access Tokens from Identity Providers + +The `idptoken.Provider` object is used to fetch and cache Access Tokens from Identity Providers (IDP). + +Example: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/acronis/go-authkit/idptoken" +) + +func main() { + // ... + httpClient := &http.Client{Timeout: 30 * time.Second} + source := idptoken.Source{ + URL: idpURL, + ClientID: clientID, + ClientSecret: clientSecret, + } + provider := idptoken.NewProvider(httpClient, source) + accessToken, err := provider.GetToken() + if err != nil { + log.Fatalf("failed to get access token: %v", err) + } + // ... +} +``` + +## License + +Copyright © 2024 Acronis International GmbH. + +Licensed under [MIT License](./LICENSE). diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..8737a9a --- /dev/null +++ b/auth.go @@ -0,0 +1,312 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "os" + "time" + + "github.com/acronis/go-appkit/log" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +// Default values. +const ( + DefaultHTTPClientRequestTimeout = time.Second * 30 + DefaultGRPCClientRequestTimeout = time.Second * 30 +) + +// NewJWTParser creates a new JWTParser with the given configuration. +// If cfg.JWT.ClaimsCache.Enabled is true, then jwt.CachingParser created, otherwise - jwt.Parser. +func NewJWTParser(cfg *Config, opts ...JWTParserOption) (JWTParser, error) { + var options jwtParserOptions + for _, opt := range opts { + opt(&options) + } + logger := options.logger + if logger == nil { + logger = log.NewDisabledLogger() + } + + // Make caching JWKS client. + jwksCacheUpdateMinInterval := cfg.JWKS.Cache.UpdateMinInterval + if jwksCacheUpdateMinInterval == 0 { + jwksCacheUpdateMinInterval = jwks.DefaultCacheUpdateMinInterval + } + httpClientRequestTimeout := cfg.HTTPClient.RequestTimeout + if httpClientRequestTimeout == 0 { + httpClientRequestTimeout = DefaultHTTPClientRequestTimeout + } + jwksClientOpts := jwks.CachingClientOpts{ + ClientOpts: jwks.ClientOpts{PrometheusLibInstanceLabel: options.prometheusLibInstanceLabel}, + CacheUpdateMinInterval: jwksCacheUpdateMinInterval, + } + jwksClient := jwks.NewCachingClientWithOpts(&http.Client{Timeout: httpClientRequestTimeout}, logger, jwksClientOpts) + + // Make JWT parser. + + if len(cfg.JWT.TrustedIssuers) == 0 && len(cfg.JWT.TrustedIssuerURLs) == 0 { + logger.Warn("list of trusted issuers is empty, jwt parsing may not work properly") + } + + parserOpts := jwt.ParserOpts{ + RequireAudience: cfg.JWT.RequireAudience, + ExpectedAudience: cfg.JWT.ExpectedAudience, + TrustedIssuerNotFoundFallback: options.trustedIssuerNotFoundFallback, + } + + if cfg.JWT.ClaimsCache.Enabled { + cachingJWTParser, err := jwt.NewCachingParserWithOpts(jwksClient, logger, jwt.CachingParserOpts{ + ParserOpts: parserOpts, + CacheMaxEntries: cfg.JWT.ClaimsCache.MaxEntries, + }) + if err != nil { + return nil, fmt.Errorf("new caching JWT parser: %w", err) + } + if err = addTrustedIssuers(cachingJWTParser, cfg.JWT.TrustedIssuers, cfg.JWT.TrustedIssuerURLs); err != nil { + return nil, err + } + return cachingJWTParser, nil + } + + jwtParser := jwt.NewParserWithOpts(jwksClient, logger, parserOpts) + if err := addTrustedIssuers(jwtParser, cfg.JWT.TrustedIssuers, cfg.JWT.TrustedIssuerURLs); err != nil { + return nil, err + } + return jwtParser, nil +} + +type jwtParserOptions struct { + logger log.FieldLogger + prometheusLibInstanceLabel string + trustedIssuerNotFoundFallback jwt.TrustedIssNotFoundFallback +} + +// JWTParserOption is an option for creating JWTParser. +type JWTParserOption func(options *jwtParserOptions) + +// WithJWTParserLogger sets the logger for JWTParser. +func WithJWTParserLogger(logger log.FieldLogger) JWTParserOption { + return func(options *jwtParserOptions) { + options.logger = logger + } +} + +// WithJWTParserPrometheusLibInstanceLabel sets the Prometheus lib instance label for JWTParser. +func WithJWTParserPrometheusLibInstanceLabel(label string) JWTParserOption { + return func(options *jwtParserOptions) { + options.prometheusLibInstanceLabel = label + } +} + +// WithJWTParserTrustedIssuerNotFoundFallback sets the fallback for JWTParser when trusted issuer is not found. +func WithJWTParserTrustedIssuerNotFoundFallback(fallback jwt.TrustedIssNotFoundFallback) JWTParserOption { + return func(options *jwtParserOptions) { + options.trustedIssuerNotFoundFallback = fallback + } +} + +// NewTokenIntrospector creates a new TokenIntrospector with the given configuration, token provider and scope filter. +// If cfg.Introspection.ClaimsCache.Enabled or cfg.Introspection.NegativeCache.Enabled is true, +// then idptoken.CachingIntrospector created, otherwise - idptoken.Introspector. +// Please note that the tokenProvider should be able to provide access token with the policy for introspection. +// scopeFilter is a list of filters that will be applied to the introspected token. +func NewTokenIntrospector( + cfg *Config, + tokenProvider idptoken.IntrospectionTokenProvider, + scopeFilter []idptoken.IntrospectionScopeFilterAccessPolicy, + opts ...TokenIntrospectorOption, +) (TokenIntrospector, error) { + var options tokenIntrospectorOptions + for _, opt := range opts { + opt(&options) + } + logger := options.logger + if logger == nil { + logger = log.NewDisabledLogger() + } + + if len(cfg.JWT.TrustedIssuers) == 0 && len(cfg.JWT.TrustedIssuerURLs) == 0 { + logger.Warn("list of trusted issuers is empty, jwt introspection may not work properly") + } + + var grpcClient *idptoken.GRPCClient + if cfg.Introspection.GRPC.Target != "" { + transportCreds, err := makeGRPCTransportCredentials(cfg.Introspection.GRPC.TLS) + if err != nil { + return nil, fmt.Errorf("make grpc transport credentials: %w", err) + } + grpcClient, err = idptoken.NewGRPCClientWithOpts(cfg.Introspection.GRPC.Target, transportCreds, + idptoken.GRPCClientOpts{RequestTimeout: cfg.GRPCClient.RequestTimeout, Logger: logger}) + if err != nil { + return nil, fmt.Errorf("new grpc client: %w", err) + } + } + + httpClientRequestTimeout := cfg.HTTPClient.RequestTimeout + if httpClientRequestTimeout == 0 { + httpClientRequestTimeout = DefaultHTTPClientRequestTimeout + } + + introspectorOpts := idptoken.IntrospectorOpts{ + StaticHTTPEndpoint: cfg.Introspection.Endpoint, + GRPCClient: grpcClient, + HTTPClient: &http.Client{Timeout: httpClientRequestTimeout}, + AccessTokenScope: cfg.Introspection.AccessTokenScope, + Logger: logger, + MinJWTVersion: cfg.Introspection.MinJWTVersion, + ScopeFilter: scopeFilter, + TrustedIssuerNotFoundFallback: options.trustedIssuerNotFoundFallback, + PrometheusLibInstanceLabel: options.prometheusLibInstanceLabel, + } + + if cfg.Introspection.ClaimsCache.Enabled || cfg.Introspection.NegativeCache.Enabled { + cachingIntrospector, err := idptoken.NewCachingIntrospectorWithOpts(tokenProvider, idptoken.CachingIntrospectorOpts{ + IntrospectorOpts: introspectorOpts, + ClaimsCache: idptoken.CachingIntrospectorCacheOpts{ + Enabled: cfg.Introspection.ClaimsCache.Enabled, + MaxEntries: cfg.Introspection.ClaimsCache.MaxEntries, + TTL: cfg.Introspection.ClaimsCache.TTL, + }, + NegativeCache: idptoken.CachingIntrospectorCacheOpts{ + Enabled: cfg.Introspection.NegativeCache.Enabled, + MaxEntries: cfg.Introspection.NegativeCache.MaxEntries, + TTL: cfg.Introspection.NegativeCache.TTL, + }, + }) + if err != nil { + return nil, fmt.Errorf("new caching introspector: %w", err) + } + if err = addTrustedIssuers(cachingIntrospector, cfg.JWT.TrustedIssuers, cfg.JWT.TrustedIssuerURLs); err != nil { + return nil, err + } + return cachingIntrospector, nil + } + + introspector := idptoken.NewIntrospectorWithOpts(tokenProvider, introspectorOpts) + if err := addTrustedIssuers(introspector, cfg.JWT.TrustedIssuers, cfg.JWT.TrustedIssuerURLs); err != nil { + return nil, err + } + return introspector, nil +} + +type tokenIntrospectorOptions struct { + logger log.FieldLogger + prometheusLibInstanceLabel string + trustedIssuerNotFoundFallback idptoken.TrustedIssNotFoundFallback +} + +// TokenIntrospectorOption is an option for creating TokenIntrospector. +type TokenIntrospectorOption func(options *tokenIntrospectorOptions) + +// WithTokenIntrospectorLogger sets the logger for TokenIntrospector. +func WithTokenIntrospectorLogger(logger log.FieldLogger) TokenIntrospectorOption { + return func(options *tokenIntrospectorOptions) { + options.logger = logger + } +} + +// WithTokenIntrospectorPrometheusLibInstanceLabel sets the Prometheus lib instance label for TokenIntrospector. +func WithTokenIntrospectorPrometheusLibInstanceLabel(label string) TokenIntrospectorOption { + return func(options *tokenIntrospectorOptions) { + options.prometheusLibInstanceLabel = label + } +} + +// WithTokenIntrospectorTrustedIssuerNotFoundFallback sets the fallback for TokenIntrospector +// when trusted issuer is not found. +func WithTokenIntrospectorTrustedIssuerNotFoundFallback( + fallback idptoken.TrustedIssNotFoundFallback, +) TokenIntrospectorOption { + return func(options *tokenIntrospectorOptions) { + options.trustedIssuerNotFoundFallback = fallback + } +} + +// Role is a representation of role which may be used for verifying access. +type Role struct { + Namespace string + Name string +} + +// NewVerifyAccessByRolesInJWT creates a new function which may be used for verifying access by roles in JWT scope. +func NewVerifyAccessByRolesInJWT(roles ...Role) func(r *http.Request, claims *jwt.Claims) bool { + return func(_ *http.Request, claims *jwt.Claims) bool { + for i := range roles { + for j := range claims.Scope { + if roles[i].Name == claims.Scope[j].Role && roles[i].Namespace == claims.Scope[j].ResourceNamespace { + return true + } + } + } + return false + } +} + +// NewVerifyAccessByRolesInJWTMaker creates a new function which may be used for verifying access by roles in JWT scope given a namespace. +func NewVerifyAccessByRolesInJWTMaker(namespace string) func(roleNames ...string) func(r *http.Request, claims *jwt.Claims) bool { + return func(roleNames ...string) func(r *http.Request, claims *jwt.Claims) bool { + roles := make([]Role, 0, len(roleNames)) + for i := range roleNames { + roles = append(roles, Role{Namespace: namespace, Name: roleNames[i]}) + } + return NewVerifyAccessByRolesInJWT(roles...) + } +} + +type issuerParser interface { + AddTrustedIssuer(issName string, issURL string) + AddTrustedIssuerURL(issURL string) error +} + +func addTrustedIssuers(issParser issuerParser, issuers map[string]string, issuerURLs []string) error { + for issName, issURL := range issuers { + issParser.AddTrustedIssuer(issName, issURL) + } + for _, issURL := range issuerURLs { + if err := issParser.AddTrustedIssuerURL(issURL); err != nil { + return fmt.Errorf("add trusted issuer URL: %w", err) + } + } + return nil +} + +func makeGRPCTransportCredentials(tlsCfg GRPCTLSConfig) (credentials.TransportCredentials, error) { + if !tlsCfg.Enabled { + return insecure.NewCredentials(), nil + } + + config := &tls.Config{} // nolint: gosec // TLS 1.2 is used by default. + if tlsCfg.CACert != "" { + caCert, err := os.ReadFile(tlsCfg.CACert) + if err != nil { + return nil, fmt.Errorf("read CA's certificate: %w", err) + } + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to add CA's certificate") + } + config.RootCAs = certPool + } + if tlsCfg.ClientCert != "" && tlsCfg.ClientKey != "" { + clientCert, err := tls.LoadX509KeyPair(tlsCfg.ClientCert, tlsCfg.ClientKey) + if err != nil { + return nil, fmt.Errorf("load client's certificate and key: %w", err) + } + config.Certificates = []tls.Certificate{clientCert} + } + return credentials.NewTLS(config), nil +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..f327704 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,411 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + gotesting "testing" + "time" + + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/idptoken/pb" + "github.com/acronis/go-authkit/internal/testing" + "github.com/acronis/go-authkit/jwt" +) + +func TestNewJWTParser(t *gotesting.T) { + const testIss = "test-issuer" + + idpSrv := idptest.NewHTTPServer() + require.NoError(t, idpSrv.StartAndWaitForReady(time.Second)) + defer func() { _ = idpSrv.Shutdown(context.Background()) }() + + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: idpSrv.URL(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)), + }, + Scope: []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "ro_admin"}}, + } + token := idptest.MustMakeTokenStringSignedWithTestKey(claims) + + claimsWithNamedIssuer := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)), + }, + Scope: []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "admin"}}, + } + tokenWithNamedIssuer := idptest.MustMakeTokenStringSignedWithTestKey(claimsWithNamedIssuer) + + tests := []struct { + name string + token string + cfg *Config + expectedClaims *jwt.Claims + checkFn func(t *gotesting.T, jwtParser JWTParser) + }{ + { + name: "new jwt parser, trusted issuers map", + cfg: &Config{JWT: JWTConfig{TrustedIssuers: map[string]string{testIss: idpSrv.URL()}}}, + token: tokenWithNamedIssuer, + expectedClaims: claimsWithNamedIssuer, + checkFn: func(t *gotesting.T, jwtParser JWTParser) { + require.IsType(t, &jwt.Parser{}, jwtParser) + }, + }, + { + name: "new jwt parser, trusted issuer urls", + cfg: &Config{JWT: JWTConfig{TrustedIssuerURLs: []string{idpSrv.URL()}}}, + token: token, + expectedClaims: claims, + checkFn: func(t *gotesting.T, jwtParser JWTParser) { + require.IsType(t, &jwt.Parser{}, jwtParser) + }, + }, + { + name: "new caching jwt parser, trusted issuers map", + cfg: &Config{JWT: JWTConfig{TrustedIssuers: map[string]string{testIss: idpSrv.URL()}, ClaimsCache: ClaimsCacheConfig{Enabled: true}}}, + token: tokenWithNamedIssuer, + expectedClaims: claimsWithNamedIssuer, + checkFn: func(t *gotesting.T, jwtParser JWTParser) { + require.IsType(t, &jwt.CachingParser{}, jwtParser) + cachingParser := jwtParser.(*jwt.CachingParser) + require.Equal(t, 1, cachingParser.ClaimsCache.Len()) + }, + }, + { + name: "new caching jwt parser, trusted issuer urls", + cfg: &Config{JWT: JWTConfig{TrustedIssuerURLs: []string{idpSrv.URL()}, ClaimsCache: ClaimsCacheConfig{Enabled: true}}}, + token: token, + expectedClaims: claims, + checkFn: func(t *gotesting.T, jwtParser JWTParser) { + require.IsType(t, &jwt.CachingParser{}, jwtParser) + cachingParser := jwtParser.(*jwt.CachingParser) + require.Equal(t, 1, cachingParser.ClaimsCache.Len()) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *gotesting.T) { + jwtParser, err := NewJWTParser(tt.cfg) + require.NoError(t, err) + + parsedClaims, err := jwtParser.Parse(context.Background(), tt.token) + require.NoError(t, err) + require.Equal(t, tt.expectedClaims, parsedClaims) + if tt.checkFn != nil { + tt.checkFn(t, jwtParser) + } + }) + } +} + +func TestNewTokenIntrospector(t *gotesting.T) { + const testIss = "test-issuer" + + httpServerIntrospector := testing.NewHTTPServerTokenIntrospectorMock() + grpcServerIntrospector := testing.NewGRPCServerTokenIntrospectorMock() + + // Start testing HTTP IDP server. + httpIDPSrv := idptest.NewHTTPServer(idptest.WithHTTPTokenIntrospector(httpServerIntrospector)) + require.NoError(t, httpIDPSrv.StartAndWaitForReady(time.Second)) + defer func() { _ = httpIDPSrv.Shutdown(context.Background()) }() + + // Generate a self-signed certificate for the testing gRPC IDP server and start it. + tlsCert, certPEM, _ := generateSelfSignedRSACert(t) + certFile := filepath.Join(t.TempDir(), "cert.pem") + require.NoError(t, os.WriteFile(certFile, certPEM, 0644)) + grpcIDPSrv := idptest.NewGRPCServer( + idptest.WithGRPCTokenIntrospector(grpcServerIntrospector), + idptest.WithGRPCServerOptions(grpc.Creds(credentials.NewServerTLSFromCert(&tlsCert)))) + require.NoError(t, grpcIDPSrv.StartAndWaitForReady(time.Second)) + defer func() { grpcIDPSrv.GracefulStop() }() + + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: httpIDPSrv.URL(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)), + }, + Scope: []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "ro_admin"}}, + } + token := idptest.MustMakeTokenStringSignedWithTestKey(claims) + + claimsWithNamedIssuer := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)), + }, + Scope: []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "admin"}}, + } + tokenWithNamedIssuer := idptest.MustMakeTokenStringSignedWithTestKey(claimsWithNamedIssuer) + + opaqueToken := "opaque-token-" + uuid.NewString() + opaqueTokenScope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "admin", + ResourcePath: "resource-" + uuid.NewString(), + }} + httpServerIntrospector.SetResultForToken(opaqueToken, idptoken.IntrospectionResult{ + Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueTokenScope}}) + grpcServerIntrospector.SetResultForToken(opaqueToken, &pb.IntrospectTokenResponse{ + Active: true, TokenType: idptoken.TokenTypeBearer, Scope: []*pb.AccessTokenScope{ + { + TenantUuid: opaqueTokenScope[0].TenantUUID, + ResourceNamespace: opaqueTokenScope[0].ResourceNamespace, + RoleName: opaqueTokenScope[0].Role, + ResourcePath: opaqueTokenScope[0].ResourcePath, + }, + }}) + + tests := []struct { + name string + cfg *Config + token string + expectedResult idptoken.IntrospectionResult + checkFn func(t *gotesting.T, introspector TokenIntrospector) + }{ + { + name: "new token introspector, dynamic endpoint, trusted issuers map", + cfg: &Config{JWT: JWTConfig{TrustedIssuers: map[string]string{testIss: httpIDPSrv.URL()}}, Introspection: IntrospectionConfig{Enabled: true}}, + token: tokenWithNamedIssuer, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: "bearer", + Claims: *claimsWithNamedIssuer, + }, + checkFn: func(t *gotesting.T, introspector TokenIntrospector) { + require.IsType(t, &idptoken.Introspector{}, introspector) + }, + }, + { + name: "new token introspector, dynamic endpoint, trusted issuer urls", + cfg: &Config{JWT: JWTConfig{TrustedIssuerURLs: []string{httpIDPSrv.URL()}}, Introspection: IntrospectionConfig{Enabled: true}}, + token: token, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: "bearer", + Claims: *claims, + }, + checkFn: func(t *gotesting.T, introspector TokenIntrospector) { + require.IsType(t, &idptoken.Introspector{}, introspector) + }, + }, + { + name: "new caching token introspector, dynamic endpoint, trusted issuers map", + cfg: &Config{ + JWT: JWTConfig{TrustedIssuers: map[string]string{testIss: httpIDPSrv.URL()}}, + Introspection: IntrospectionConfig{Enabled: true, ClaimsCache: IntrospectionCacheConfig{Enabled: true}}, + }, + token: tokenWithNamedIssuer, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: "bearer", + Claims: *claimsWithNamedIssuer, + }, + checkFn: func(t *gotesting.T, introspector TokenIntrospector) { + require.IsType(t, &idptoken.CachingIntrospector{}, introspector) + cachingIntrospector := introspector.(*idptoken.CachingIntrospector) + require.Equal(t, 1, cachingIntrospector.ClaimsCache.Len(context.Background())) + }, + }, + { + name: "new caching token introspector, dynamic endpoint, trusted issuer urls", + cfg: &Config{ + JWT: JWTConfig{TrustedIssuerURLs: []string{httpIDPSrv.URL()}}, + Introspection: IntrospectionConfig{Enabled: true, ClaimsCache: IntrospectionCacheConfig{Enabled: true}}, + }, + token: token, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: "bearer", + Claims: *claims, + }, + checkFn: func(t *gotesting.T, introspector TokenIntrospector) { + require.IsType(t, &idptoken.CachingIntrospector{}, introspector) + cachingIntrospector := introspector.(*idptoken.CachingIntrospector) + require.Equal(t, 1, cachingIntrospector.ClaimsCache.Len(context.Background())) + }, + }, + { + name: "new caching token introspector, static http endpoint", + cfg: &Config{ + Introspection: IntrospectionConfig{ + Enabled: true, + ClaimsCache: IntrospectionCacheConfig{Enabled: true}, + Endpoint: httpIDPSrv.URL() + idptest.TokenIntrospectionEndpointPath, + }, + }, + token: opaqueToken, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: "bearer", + Claims: jwt.Claims{Scope: opaqueTokenScope}, + }, + checkFn: func(t *gotesting.T, introspector TokenIntrospector) { + require.IsType(t, &idptoken.CachingIntrospector{}, introspector) + cachingIntrospector := introspector.(*idptoken.CachingIntrospector) + require.Equal(t, 1, cachingIntrospector.ClaimsCache.Len(context.Background())) + }, + }, + { + name: "new token introspector, gRPC target, tls enabled", + cfg: &Config{ + JWT: JWTConfig{TrustedIssuerURLs: []string{httpIDPSrv.URL()}}, + Introspection: IntrospectionConfig{ + Enabled: true, + GRPC: IntrospectionGRPCConfig{ + Target: grpcIDPSrv.Addr(), + TLS: GRPCTLSConfig{ + Enabled: true, + CACert: certFile, + }, + }, + Endpoint: httpIDPSrv.URL() + idptest.TokenIntrospectionEndpointPath, + }, + }, + token: opaqueToken, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: "bearer", + Claims: jwt.Claims{Scope: opaqueTokenScope}, + }, + checkFn: func(t *gotesting.T, introspector TokenIntrospector) { + require.IsType(t, &idptoken.Introspector{}, introspector) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *gotesting.T) { + jwtParser, err := NewJWTParser(tt.cfg) + require.NoError(t, err) + httpServerIntrospector.JWTParser = jwtParser + grpcServerIntrospector.JWTParser = jwtParser + + introspector, err := NewTokenIntrospector(tt.cfg, idptest.NewSimpleTokenProvider("access-token"), nil) + require.NoError(t, err) + + result, err := introspector.IntrospectToken(context.Background(), tt.token) + require.NoError(t, err) + require.Equal(t, tt.expectedResult, result) + if tt.checkFn != nil { + tt.checkFn(t, introspector) + } + }) + } +} + +func TestNewVerifyAccessByJWTRoles(t *gotesting.T) { + jwtClaims := &jwt.Claims{Scope: []jwt.AccessPolicy{ + {ResourceNamespace: "policy_manager", Role: "admin"}, + {ResourceNamespace: "scan_service", Role: "admin"}, + {Role: "backup_user"}, + {ResourceNamespace: "agent_manager", Role: "agent_viewer"}, + }} + cases := []struct { + roles []Role + want bool + }{ + {[]Role{{Name: "tenant_viewer"}}, false}, + {[]Role{{Name: "backup_user"}}, true}, + {[]Role{{Namespace: "alert_manager", Name: "admin"}}, false}, + {[]Role{{Namespace: "policy_manager", Name: "admin"}}, true}, + {[]Role{{Namespace: "alert_manager", Name: "admin"}, {Name: "tenant_viewer"}}, false}, + {[]Role{{Namespace: "alert_manager", Name: "admin"}, {Name: "tenant_viewer"}, {Namespace: "policy_manager", Name: "admin"}}, true}, + } + for _, c := range cases { + got := NewVerifyAccessByRolesInJWT(c.roles...)(httptest.NewRequest(http.MethodGet, "/", nil), jwtClaims) + require.Equal(t, c.want, got, "want %v, got %v, roles %+v", c.want, got, c.roles) + } +} + +func TestNewVerifyAccessByJWTRolesMaker(t *gotesting.T) { + jwtClaims := &jwt.Claims{Scope: []jwt.AccessPolicy{ + {ResourceNamespace: "policy_manager", Role: "admin"}, + {ResourceNamespace: "scan_service", Role: "admin"}, + {Role: "backup_user"}, + {ResourceNamespace: "agent_manager", Role: "agent_viewer"}, + {ResourceNamespace: "agent_manager", Role: "agent_registrar"}, + }} + cases := []struct { + roleNamespace string + roleNames []string + want bool + }{ + {"agent_manager", []string{"agent_viewer", "agent_registrar"}, true}, + {"policy_manager", []string{"admin"}, true}, + {"", []string{"backup_user"}, true}, + {"alert_manager", []string{"admin"}, false}, + } + + for _, c := range cases { + got := NewVerifyAccessByRolesInJWTMaker(c.roleNamespace)(c.roleNames...)(httptest.NewRequest(http.MethodGet, "/", nil), jwtClaims) + require.Equal(t, c.want, got, "want %v, got %v, roleNamespace: %v, roleNames %+v", c.want, got, c.roleNamespace, c.roleNames) + } +} + +func generateSelfSignedRSACert(t *gotesting.T) (tlsCert tls.Certificate, certPEM []byte, keyPEM []byte) { + t.Helper() + + // Create a private key + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + // Create a certificate template + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) // 1 year + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + require.NoError(t, err) + + certTemplate := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"My Organization"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + // Create the certificate + certDER, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv) + require.NoError(t, err) + + // PEM encode the certificate and private key + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + // Load the certificate and key as tls.Certificate + tlsCert, err = tls.X509KeyPair(certPEM, keyPEM) + require.NoError(t, err) + + return tlsCert, certPEM, keyPEM +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..64c4366 --- /dev/null +++ b/config.go @@ -0,0 +1,296 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "fmt" + "net/url" + "time" + + "github.com/acronis/go-appkit/config" + + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +const ( + cfgKeyHTTPClientRequestTimeout = "auth.httpClient.requestTimeout" + cfgKeyGRPCClientRequestTimeout = "auth.grpcClient.requestTimeout" + cfgKeyJWTTrustedIssuers = "auth.jwt.trustedIssuers" + cfgKeyJWTTrustedIssuerURLs = "auth.jwt.trustedIssuerUrls" + cfgKeyJWTRequireAudience = "auth.jwt.requireAudience" + cfgKeyJWTExceptedAudience = "auth.jwt.expectedAudience" + cfgKeyJWTClaimsCacheEnabled = "auth.jwt.claimsCache.enabled" + cfgKeyJWTClaimsCacheMaxEntries = "auth.jwt.claimsCache.maxEntries" + cfgKeyJWKSCacheUpdateMinInterval = "auth.jwks.cache.updateMinInterval" + cfgKeyIntrospectionEnabled = "auth.introspection.enabled" + cfgKeyIntrospectionEndpoint = "auth.introspection.endpoint" + cfgKeyIntrospectionGRPCTarget = "auth.introspection.grpc.target" + cfgKeyIntrospectionGRPCTLSEnabled = "auth.introspection.grpc.tls.enabled" + cfgKeyIntrospectionGRPCTLSCACert = "auth.introspection.grpc.tls.caCert" + cfgKeyIntrospectionGRPCTLSClientCert = "auth.introspection.grpc.tls.clientCert" + cfgKeyIntrospectionGRPCTLSClientKey = "auth.introspection.grpc.tls.clientKey" + cfgKeyIntrospectionAccessTokenScope = "auth.introspection.accessTokenScope" // nolint:gosec // false positive + cfgKeyIntrospectionMinJWTVer = "auth.introspection.minJWTVersion" + cfgKeyIntrospectionClaimsCacheEnabled = "auth.introspection.claimsCache.enabled" + cfgKeyIntrospectionClaimsCacheMaxEntries = "auth.introspection.claimsCache.maxEntries" + cfgKeyIntrospectionClaimsCacheTTL = "auth.introspection.claimsCache.ttl" + cfgKeyIntrospectionNegativeCacheEnabled = "auth.introspection.negativeCache.enabled" + cfgKeyIntrospectionNegativeCacheMaxEntries = "auth.introspection.negativeCache.maxEntries" + cfgKeyIntrospectionNegativeCacheTTL = "auth.introspection.negativeCache.ttl" +) + +// JWTConfig is configuration of how JWT will be verified. +type JWTConfig struct { + TrustedIssuers map[string]string + TrustedIssuerURLs []string + RequireAudience bool + ExpectedAudience []string + ClaimsCache ClaimsCacheConfig +} + +// JWKSConfig is configuration of how JWKS will be used. +type JWKSConfig struct { + Cache struct { + UpdateMinInterval time.Duration + } +} + +// IntrospectionConfig is a configuration of how token introspection will be used. +type IntrospectionConfig struct { + Enabled bool + + Endpoint string + AccessTokenScope []string + + // MinJWTVersion is a minimum version of JWT that will be accepted for introspection. + // NOTE: it's a temporary solution for determining whether introspection is needed or not, + // and it will be removed in the future. + MinJWTVersion int + + ClaimsCache IntrospectionCacheConfig + NegativeCache IntrospectionCacheConfig + + GRPC IntrospectionGRPCConfig +} + +// ClaimsCacheConfig is a configuration of how claims cache will be used. +type ClaimsCacheConfig struct { + Enabled bool + MaxEntries int +} + +// IntrospectionCacheConfig is a configuration of how claims cache will be used for introspection. +type IntrospectionCacheConfig struct { + Enabled bool + MaxEntries int + TTL time.Duration +} + +// IntrospectionGRPCConfig is a configuration of how token will be introspected via gRPC. +type IntrospectionGRPCConfig struct { + Target string + RequestTimeout time.Duration + TLS GRPCTLSConfig +} + +// GRPCTLSConfig is a configuration of how gRPC connection will be secured. +type GRPCTLSConfig struct { + Enabled bool + CACert string + ClientCert string + ClientKey string +} + +type HTTPClientConfig struct { + RequestTimeout time.Duration +} + +type GRPCClientConfig struct { + RequestTimeout time.Duration +} + +// Config represents a set of configuration parameters for authentication and authorization. +type Config struct { + HTTPClient HTTPClientConfig + GRPCClient GRPCClientConfig + + JWT JWTConfig + JWKS JWKSConfig + Introspection IntrospectionConfig + + keyPrefix string +} + +var _ config.Config = (*Config)(nil) +var _ config.KeyPrefixProvider = (*Config)(nil) + +// NewConfig creates a new instance of the Config. +func NewConfig() *Config { + return NewConfigWithKeyPrefix("") +} + +// NewConfigWithKeyPrefix creates a new instance of the Config. +// Allows specifying key prefix which will be used for parsing configuration parameters. +func NewConfigWithKeyPrefix(keyPrefix string) *Config { + return &Config{keyPrefix: keyPrefix} +} + +// KeyPrefix returns a key prefix with which all configuration parameters should be presented. +func (c *Config) KeyPrefix() string { + return c.keyPrefix +} + +// SetProviderDefaults sets default configuration values for auth in config.DataProvider. +func (c *Config) SetProviderDefaults(dp config.DataProvider) { + dp.SetDefault(cfgKeyHTTPClientRequestTimeout, DefaultHTTPClientRequestTimeout.String()) + dp.SetDefault(cfgKeyGRPCClientRequestTimeout, DefaultGRPCClientRequestTimeout.String()) + dp.SetDefault(cfgKeyJWTClaimsCacheMaxEntries, jwt.DefaultClaimsCacheMaxEntries) + dp.SetDefault(cfgKeyJWKSCacheUpdateMinInterval, jwks.DefaultCacheUpdateMinInterval.String()) + dp.SetDefault(cfgKeyIntrospectionMinJWTVer, idptoken.MinJWTVersionForIntrospection) + dp.SetDefault(cfgKeyIntrospectionClaimsCacheMaxEntries, idptoken.DefaultIntrospectionClaimsCacheMaxEntries) + dp.SetDefault(cfgKeyIntrospectionClaimsCacheTTL, idptoken.DefaultIntrospectionClaimsCacheTTL.String()) + dp.SetDefault(cfgKeyIntrospectionNegativeCacheMaxEntries, idptoken.DefaultIntrospectionNegativeCacheMaxEntries) + dp.SetDefault(cfgKeyIntrospectionNegativeCacheTTL, idptoken.DefaultIntrospectionNegativeCacheTTL.String()) +} + +// Set sets auth configuration values from config.DataProvider. +func (c *Config) Set(dp config.DataProvider) error { + var err error + + if c.HTTPClient.RequestTimeout, err = dp.GetDuration(cfgKeyHTTPClientRequestTimeout); err != nil { + return err + } + if c.GRPCClient.RequestTimeout, err = dp.GetDuration(cfgKeyGRPCClientRequestTimeout); err != nil { + return err + } + if err = c.setJWTConfig(dp); err != nil { + return err + } + if err = c.setJWKSConfig(dp); err != nil { + return err + } + if err = c.setIntrospectionConfig(dp); err != nil { + return err + } + + return nil +} + +func (c *Config) setJWTConfig(dp config.DataProvider) error { + var err error + + if c.JWT.TrustedIssuers, err = dp.GetStringMapString(cfgKeyJWTTrustedIssuers); err != nil { + return err + } + if c.JWT.TrustedIssuerURLs, err = dp.GetStringSlice(cfgKeyJWTTrustedIssuerURLs); err != nil { + return err + } + for _, issURL := range c.JWT.TrustedIssuerURLs { + if _, err = url.Parse(issURL); err != nil { + return dp.WrapKeyErr(cfgKeyJWTTrustedIssuerURLs, err) + } + } + if c.JWT.RequireAudience, err = dp.GetBool(cfgKeyJWTRequireAudience); err != nil { + return err + } + if c.JWT.ExpectedAudience, err = dp.GetStringSlice(cfgKeyJWTExceptedAudience); err != nil { + return err + } + if c.JWT.ClaimsCache.Enabled, err = dp.GetBool(cfgKeyJWTClaimsCacheEnabled); err != nil { + return err + } + if c.JWT.ClaimsCache.MaxEntries, err = dp.GetInt(cfgKeyJWTClaimsCacheMaxEntries); err != nil { + return err + } + if c.JWT.ClaimsCache.MaxEntries < 0 { + return dp.WrapKeyErr(cfgKeyJWTClaimsCacheMaxEntries, fmt.Errorf("max entries should be non-negative")) + } + + return nil +} + +func (c *Config) setJWKSConfig(dp config.DataProvider) error { + var err error + if c.JWKS.Cache.UpdateMinInterval, err = dp.GetDuration(cfgKeyJWKSCacheUpdateMinInterval); err != nil { + return err + } + return nil +} + +func (c *Config) setIntrospectionConfig(dp config.DataProvider) error { + var err error + + if c.Introspection.Enabled, err = dp.GetBool(cfgKeyIntrospectionEnabled); err != nil { + return err + } + if c.Introspection.Endpoint, err = dp.GetString(cfgKeyIntrospectionEndpoint); err != nil { + return err + } + if _, err = url.Parse(c.Introspection.Endpoint); err != nil { + return dp.WrapKeyErr(cfgKeyIntrospectionEndpoint, err) + } + + // GRPC + if c.Introspection.GRPC.Target, err = dp.GetString(cfgKeyIntrospectionGRPCTarget); err != nil { + return err + } + if c.Introspection.GRPC.TLS.Enabled, err = dp.GetBool(cfgKeyIntrospectionGRPCTLSEnabled); err != nil { + return err + } + if c.Introspection.GRPC.TLS.CACert, err = dp.GetString(cfgKeyIntrospectionGRPCTLSCACert); err != nil { + return err + } + if c.Introspection.GRPC.TLS.ClientCert, err = dp.GetString(cfgKeyIntrospectionGRPCTLSClientCert); err != nil { + return err + } + if c.Introspection.GRPC.TLS.ClientKey, err = dp.GetString(cfgKeyIntrospectionGRPCTLSClientKey); err != nil { + return err + } + + if c.Introspection.AccessTokenScope, err = dp.GetStringSlice(cfgKeyIntrospectionAccessTokenScope); err != nil { + return err + } + + if c.Introspection.MinJWTVersion, err = dp.GetInt(cfgKeyIntrospectionMinJWTVer); err != nil { + return err + } + if c.Introspection.MinJWTVersion < 0 { + return dp.WrapKeyErr(cfgKeyIntrospectionMinJWTVer, fmt.Errorf("minimum JWT version should be non-negative")) + } + + // Claims cache + if c.Introspection.ClaimsCache.Enabled, err = dp.GetBool(cfgKeyIntrospectionClaimsCacheEnabled); err != nil { + return err + } + if c.Introspection.ClaimsCache.MaxEntries, err = dp.GetInt(cfgKeyIntrospectionClaimsCacheMaxEntries); err != nil { + return err + } + if c.Introspection.ClaimsCache.MaxEntries < 0 { + return dp.WrapKeyErr(cfgKeyIntrospectionClaimsCacheMaxEntries, fmt.Errorf("max entries should be non-negative")) + } + if c.Introspection.ClaimsCache.TTL, err = dp.GetDuration(cfgKeyIntrospectionClaimsCacheTTL); err != nil { + return err + } + + // Negative cache + if c.Introspection.NegativeCache.Enabled, err = dp.GetBool(cfgKeyIntrospectionNegativeCacheEnabled); err != nil { + return err + } + if c.Introspection.NegativeCache.MaxEntries, err = dp.GetInt(cfgKeyIntrospectionNegativeCacheMaxEntries); err != nil { + return err + } + if c.Introspection.NegativeCache.MaxEntries < 0 { + return dp.WrapKeyErr(cfgKeyIntrospectionNegativeCacheMaxEntries, fmt.Errorf("max entries should be non-negative")) + } + if c.Introspection.NegativeCache.TTL, err = dp.GetDuration(cfgKeyIntrospectionNegativeCacheTTL); err != nil { + return err + } + + return nil +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..81809cd --- /dev/null +++ b/config_test.go @@ -0,0 +1,265 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "bytes" + "strings" + "testing" + "time" + + "github.com/acronis/go-appkit/config" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/jwt" +) + +func TestConfig_Set(t *testing.T) { + t.Run("ok", func(t *testing.T) { + cfgData := bytes.NewBufferString(` +auth: + httpClient: + requestTimeout: 1m + grpcClient: + requestTimeout: 2m + jwt: + trustedIssuers: + my-issuer1: https://my-issuer1.com/idp + my-issuer2: https://my-issuer2.com/idp + trustedIssuerUrls: + - https://*.my-company1.com/idp + - https://*.my-company2.com/idp + requireAudience: true + expectedAudience: + - https://*.my-company1.com + - https://*.my-company2.com + jwks: + httpclient: + timeout: 1m + cache: + updateMinInterval: 5m + introspection: + enabled: true + endpoint: https://my-idp.com/introspect + claimsCache: + enabled: true + maxEntries: 42000 + ttl: 42s + negativeCache: + enabled: true + maxEntries: 777 + ttl: 77s + accessTokenScope: + - token_introspector + minJWTVersion: 3 + grpc: + target: "127.0.0.1:1234" + tls: + enabled: true + caCert: ca-cert.pem + clientCert: client-cert.pem + clientKey: client-key.pem +`) + cfg := Config{} + err := config.NewDefaultLoader("").LoadFromReader(cfgData, config.DataTypeYAML, &cfg) + require.NoError(t, err) + require.Equal(t, time.Minute*1, cfg.HTTPClient.RequestTimeout) + require.Equal(t, time.Minute*2, cfg.GRPCClient.RequestTimeout) + require.Equal(t, cfg.JWT, JWTConfig{ + TrustedIssuers: map[string]string{ + "my-issuer1": "https://my-issuer1.com/idp", + "my-issuer2": "https://my-issuer2.com/idp", + }, + TrustedIssuerURLs: []string{ + "https://*.my-company1.com/idp", + "https://*.my-company2.com/idp", + }, + RequireAudience: true, + ExpectedAudience: []string{ + "https://*.my-company1.com", + "https://*.my-company2.com", + }, + ClaimsCache: ClaimsCacheConfig{ + MaxEntries: jwt.DefaultClaimsCacheMaxEntries, + }, + }) + require.Equal(t, time.Minute*5, cfg.JWKS.Cache.UpdateMinInterval) + require.Equal(t, cfg.Introspection, IntrospectionConfig{ + Enabled: true, + Endpoint: "https://my-idp.com/introspect", + ClaimsCache: IntrospectionCacheConfig{ + Enabled: true, + MaxEntries: 42000, + TTL: time.Second * 42, + }, + NegativeCache: IntrospectionCacheConfig{ + Enabled: true, + MaxEntries: 777, + TTL: time.Second * 77, + }, + AccessTokenScope: []string{"token_introspector"}, + MinJWTVersion: 3, + GRPC: IntrospectionGRPCConfig{ + Target: "127.0.0.1:1234", + TLS: GRPCTLSConfig{ + Enabled: true, + CACert: "ca-cert.pem", + ClientCert: "client-cert.pem", + ClientKey: "client-key.pem", + }, + }, + }) + }) +} + +func TestConfig_SetErrors(t *testing.T) { + tests := []struct { + name string + cfgData string + errKey string + errMsg string + }{ + { + name: "invalid trusted issuer URL", + cfgData: ` +auth: + jwt: + trustedIssuerURLs: + - ://invalid-url +`, + errKey: cfgKeyJWTTrustedIssuerURLs, + errMsg: "missing protocol scheme", + }, + { + name: "negative claims cache max entries", + cfgData: ` +auth: + jwt: + claimsCache: + maxEntries: -1 +`, + errKey: cfgKeyJWTClaimsCacheMaxEntries, + errMsg: "max entries should be non-negative", + }, + { + name: "invalid HTTP client timeout", + cfgData: ` +auth: + httpClient: + requestTimeout: invalid +`, + errKey: cfgKeyHTTPClientRequestTimeout, + errMsg: "invalid duration", + }, + { + name: "invalid gRPC client timeout", + cfgData: ` +auth: + grpcClient: + requestTimeout: invalid +`, + errKey: cfgKeyGRPCClientRequestTimeout, + errMsg: "invalid duration", + }, + { + name: "invalid cache update min interval", + cfgData: ` +auth: + jwks: + cache: + updateMinInterval: invalid +`, + errKey: cfgKeyJWKSCacheUpdateMinInterval, + errMsg: "invalid duration", + }, + { + name: "invalid introspection endpoint URL", + cfgData: ` +auth: + introspection: + endpoint: ://invalid-url +`, + errKey: cfgKeyIntrospectionEndpoint, + errMsg: "missing protocol scheme", + }, + { + name: "negative introspection claims cache max entries", + cfgData: ` +auth: + introspection: + claimsCache: + maxEntries: -1 +`, + errKey: cfgKeyIntrospectionClaimsCacheMaxEntries, + errMsg: "max entries should be non-negative", + }, + { + name: "negative introspection negative cache max entries", + cfgData: ` +auth: + introspection: + negativeCache: + maxEntries: -1 +`, + errKey: cfgKeyIntrospectionNegativeCacheMaxEntries, + errMsg: "max entries should be non-negative", + }, + { + name: "invalid introspection claims cache TTL", + cfgData: ` +auth: + introspection: + claimsCache: + ttl: invalid +`, + errKey: cfgKeyIntrospectionClaimsCacheTTL, + errMsg: "invalid duration", + }, + { + name: "invalid introspection negative cache TTL", + cfgData: ` +auth: + introspection: + negativeCache: + ttl: invalid +`, + errKey: cfgKeyIntrospectionNegativeCacheTTL, + errMsg: "invalid duration", + }, + { + name: "invalid introspection access token scope", + cfgData: ` +auth: + introspection: + accessTokenScope: {} +`, + errKey: cfgKeyIntrospectionAccessTokenScope, + errMsg: " unable to cast", + }, + { + name: "negative introspection min JWT version", + cfgData: ` +auth: + introspection: + minJWTVersion: -1 +`, + errKey: cfgKeyIntrospectionMinJWTVer, + errMsg: "minimum JWT version should be non-negative", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgData := bytes.NewBufferString(tt.cfgData) + cfg := Config{} + err := config.NewDefaultLoader("").LoadFromReader(cfgData, config.DataTypeYAML, &cfg) + require.ErrorContains(t, err, tt.errMsg) + require.Truef(t, strings.HasPrefix(err.Error(), tt.errKey), + "expected error starts with %q, got %q", tt.errKey, err.Error()) + }) + } +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..fc4062e --- /dev/null +++ b/doc.go @@ -0,0 +1,8 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package auth provides high-level helpers and basic objects for authN/authZ. +package auth diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..f4e7405 --- /dev/null +++ b/example_test.go @@ -0,0 +1,196 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "time" + + jwtgo "github.com/golang-jwt/jwt/v5" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/jwt" +) + +func ExampleJWTAuthMiddleware() { + jwtConfig := JWTConfig{ + TrustedIssuerURLs: []string{"https://my-idp.com"}, + //TrustedIssuers: map[string]string{"my-idp": "https://my-idp.com"}, // Use TrustedIssuers if you have a custom issuer name. + } + jwtParser, _ := NewJWTParser(&Config{JWT: jwtConfig}) + authN := JWTAuthMiddleware("MyService", jwtParser) + + srvMux := http.NewServeMux() + srvMux.Handle("/", http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + _, _ = rw.Write([]byte("Hello, World!")) + })) + srvMux.Handle("/admin", authN(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + //jwtClaims := GetJWTClaimsFromContext(r.Context()) // GetJWTClaimsFromContext is a helper function to get JWT claims from context. + _, _ = rw.Write([]byte("Hello, admin!")) + }))) + + done := make(chan struct{}) + server := &http.Server{Addr: ":8080", Handler: srvMux} + go func() { + defer close(done) + _ = server.ListenAndServe() + }() + + time.Sleep(time.Second) // Wait for the server to start. + + client := &http.Client{Timeout: time.Second * 30} + + fmt.Println("GET http://localhost:8080/") + resp, _ := client.Get("http://localhost:8080/") + fmt.Println("Status code:", resp.StatusCode) + respBody, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + fmt.Println("------") + fmt.Println("GET http://localhost:8080/admin without token") + resp, _ = client.Get("http://localhost:8080/admin") + fmt.Println("Status code:", resp.StatusCode) + respBody, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + fmt.Println("------") + fmt.Println("GET http://localhost:8080/admin with invalid token") + req, _ := http.NewRequest(http.MethodGet, "http://localhost:8080/admin", http.NoBody) + req.Header["Authorization"] = []string{"Bearer invalid-token"} + resp, _ = client.Do(req) + fmt.Println("Status code:", resp.StatusCode) + respBody, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + _ = server.Shutdown(context.Background()) + <-done + + // Output: + // GET http://localhost:8080/ + // Status code: 200 + // Body: Hello, World! + // ------ + // GET http://localhost:8080/admin without token + // Status code: 401 + // Body: {"error":{"domain":"MyService","code":"bearerTokenMissing","message":"Authorization bearer token is missing."}} + // ------ + // GET http://localhost:8080/admin with invalid token + // Status code: 401 + // Body: {"error":{"domain":"MyService","code":"authenticationFailed","message":"Authentication is failed."}} +} + +func ExampleJWTAuthMiddlewareWithVerifyAccess() { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + roUserClaims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: "my-idp", + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(2 * time.Hour)), + }, + Scope: []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "read-only-user"}}, + } + roUserToken := idptest.MustMakeTokenStringSignedWithTestKey(roUserClaims) + + adminClaims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: "my-idp", + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(2 * time.Hour)), + }, + Scope: []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "admin"}}, + } + adminToken := idptest.MustMakeTokenStringSignedWithTestKey(adminClaims) + + jwtConfig := JWTConfig{TrustedIssuers: map[string]string{"my-idp": issuerConfigServer.URL}} + jwtParser, _ := NewJWTParser(&Config{JWT: jwtConfig}) + authOnlyAdmin := JWTAuthMiddleware("MyService", jwtParser, + WithJWTAuthMiddlewareVerifyAccess(NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"}))) + + srvMux := http.NewServeMux() + srvMux.Handle("/", http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + _, _ = rw.Write([]byte("Hello, World!")) + })) + srvMux.Handle("/admin", authOnlyAdmin(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + _, _ = rw.Write([]byte("Hello, admin!")) + }))) + + done := make(chan struct{}) + server := &http.Server{Addr: ":8080", Handler: srvMux} + go func() { + defer close(done) + _ = server.ListenAndServe() + }() + + time.Sleep(time.Second) // Wait for the server to start. + + client := &http.Client{Timeout: time.Second * 30} + + fmt.Println("GET http://localhost:8080/") + resp, _ := client.Get("http://localhost:8080/") + fmt.Println("Status code:", resp.StatusCode) + respBody, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + fmt.Println("------") + fmt.Println("GET http://localhost:8080/admin without token") + resp, _ = client.Get("http://localhost:8080/admin") + fmt.Println("Status code:", resp.StatusCode) + respBody, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + fmt.Println("------") + fmt.Println("GET http://localhost:8080/admin with token of read-only user") + req, _ := http.NewRequest(http.MethodGet, "http://localhost:8080/admin", http.NoBody) + req.Header["Authorization"] = []string{"Bearer " + roUserToken} + resp, _ = client.Do(req) + fmt.Println("Status code:", resp.StatusCode) + respBody, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + fmt.Println("------") + fmt.Println("GET http://localhost:8080/admin with token of admin user") + req, _ = http.NewRequest(http.MethodGet, "http://localhost:8080/admin", http.NoBody) + req.Header["Authorization"] = []string{"Bearer " + adminToken} + resp, _ = client.Do(req) + fmt.Println("Status code:", resp.StatusCode) + respBody, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + _ = server.Shutdown(context.Background()) + <-done + + // Output: + // GET http://localhost:8080/ + // Status code: 200 + // Body: Hello, World! + // ------ + // GET http://localhost:8080/admin without token + // Status code: 401 + // Body: {"error":{"domain":"MyService","code":"bearerTokenMissing","message":"Authorization bearer token is missing."}} + // ------ + // GET http://localhost:8080/admin with token of read-only user + // Status code: 403 + // Body: {"error":{"domain":"MyService","code":"authorizationFailed","message":"Authorization is failed."}} + // ------ + // GET http://localhost:8080/admin with token of admin user + // Status code: 200 + // Body: Hello, admin! +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d710935 --- /dev/null +++ b/go.mod @@ -0,0 +1,58 @@ +module github.com/acronis/go-authkit + +go 1.20 + +require ( + github.com/acronis/go-appkit v1.3.0 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/google/uuid v1.6.0 + github.com/mendsley/gojwk v0.0.0-20141217222730-4d5ec6e58103 + github.com/prometheus/client_golang v1.20.4 + github.com/stretchr/testify v1.9.0 + github.com/vasayxtx/go-glob v1.2.0 + golang.org/x/sync v0.8.0 + google.golang.org/grpc v1.64.1 + google.golang.org/protobuf v1.34.2 +) + +require ( + code.cloudfoundry.org/bytefmt v0.0.0-20240808182453-a379845013d9 // indirect + github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/hashicorp/golang-lru v1.0.2 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.55.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect + github.com/rs/xid v1.5.0 // indirect + github.com/sagikazarmark/locafero v0.6.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.7.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/spf13/viper v1.19.0 // indirect + github.com/ssgreg/logf v1.4.2 // indirect + github.com/ssgreg/logftext v1.1.1 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + github.com/throttled/throttled/v2 v2.12.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect + golang.org/x/net v0.28.0 // indirect + golang.org/x/sys v0.24.0 // indirect + golang.org/x/text v0.17.0 // indirect + golang.org/x/time v0.6.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8a2a365 --- /dev/null +++ b/go.sum @@ -0,0 +1,221 @@ +code.cloudfoundry.org/bytefmt v0.0.0-20240808182453-a379845013d9 h1:8KlrGCtoaWaaxVxi9KzED38kNIWa1qafh9bNSVZ6otk= +code.cloudfoundry.org/bytefmt v0.0.0-20240808182453-a379845013d9/go.mod h1:eF2ZbltNI7Pv+8Cuyeksu9up5FN5konuH0trDJBuscw= +github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b h1:5/++qT1/z812ZqBvqQt6ToRswSuPZ/B33m6xVHRzADU= +github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b/go.mod h1:4+EPqMRApwwE/6yo6CxiHoSnBzjRr3jsqer7frxP8y4= +github.com/acronis/go-appkit v1.3.0 h1:IaX0DbD7HWp8ykqnK9F+c8757AmP4uBHVBe8J0Wv2sw= +github.com/acronis/go-appkit v1.3.0/go.mod h1:ouqWNe1/69fwjhx+2vV81Y6iqstfDzhmC6HpZ0E/gp4= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w= +github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-redis/redis v6.15.8+incompatible h1:BKZuG6mCnRj5AOaWJXoCgf6rqTYnYJLe4en2hxT7r9o= +github.com/go-redis/redis v6.15.8+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/go-redis/redis/v8 v8.4.2/go.mod h1:A1tbYoHSa1fXwN+//ljcCYYJeLmVrwL9hbQN45Jdy0M= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= +github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= +github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mendsley/gojwk v0.0.0-20141217222730-4d5ec6e58103 h1:Z/i1e+gTZrmcGeZyWckaLfucYG6KYOXLWo4co8pZYNY= +github.com/mendsley/gojwk v0.0.0-20141217222730-4d5ec6e58103/go.mod h1:o9YPB5aGP8ob35Vy6+vyq3P3bWe7NQWzf+JLiXCiMaE= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M= +github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo/v2 v2.20.0 h1:PE84V2mHqoT1sglvHc8ZdQtPcwmvvt29WLEEO3xmdZw= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= +github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= +github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/redis/go-redis/v9 v9.0.5/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/sagikazarmark/locafero v0.6.0 h1:ON7AQg37yzcRPU69mt7gwhFEBwxI6P9T4Qu3N51bwOk= +github.com/sagikazarmark/locafero v0.6.0/go.mod h1:77OmuIc6VTraTXKXIs/uvUxKGUXjE1GbemJYHqdNjX0= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= +github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= +github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg= +github.com/ssgreg/logf v1.3.1/go.mod h1:s7bKemHNzeAi8OePMgR93dqfL4Swro4W3B2jSIyypl4= +github.com/ssgreg/logf v1.4.2 h1:J5qO5lVhFuHboQjYyTNt+0HlQifAYaLHgZLBxpDNIQQ= +github.com/ssgreg/logf v1.4.2/go.mod h1:s7bKemHNzeAi8OePMgR93dqfL4Swro4W3B2jSIyypl4= +github.com/ssgreg/logftext v1.1.1 h1:vq03mtTnUhmnznKwMeoW+mZrH8HxXUTzGDsxL6I6YMo= +github.com/ssgreg/logftext v1.1.1/go.mod h1:ONi7K7Bilp5Amyhq3sdVQ1lzzp4n4TyyVP4vpx962mA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/throttled/throttled/v2 v2.12.0 h1:IezKE1uHlYC/0Al05oZV6Ar+uN/znw3cy9J8banxhEY= +github.com/throttled/throttled/v2 v2.12.0/go.mod h1:+EAvrG2hZAQTx8oMpBu8fq6Xmm+d1P2luKK7fIY1Esc= +github.com/vasayxtx/go-glob v1.2.0 h1:t+/v9ROAeUVD2OLMcoS7yF6ojqaXSSRInAJ0vWOTU1g= +github.com/vasayxtx/go-glob v1.2.0/go.mod h1:wEj3yNgEm7emEVHCleh9WlNRoW9r3OsajUFgPvSLle0= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opentelemetry.io/otel v0.14.0/go.mod h1:vH5xEuwy7Rts0GNtsCW3HYQoZDY+OmBJ6t1bFGGlxgw= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI= +golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211111213525-f221eed1c01e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= +golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 h1:e7S5W7MGGLaSu8j3YjdezkZ+m1/Nm0uRVRMEMGk26Xs= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= +google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/idptest/doc.go b/idptest/doc.go new file mode 100644 index 0000000..518a8d4 --- /dev/null +++ b/idptest/doc.go @@ -0,0 +1,10 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package idptest provides helper primitives and functions required for +// testing signing and key generation and a simple HTTP server +// with JWKS, issuer and IDP configuration endpoints. +package idptest diff --git a/idptest/grpc_server.go b/idptest/grpc_server.go new file mode 100644 index 0000000..35884af --- /dev/null +++ b/idptest/grpc_server.go @@ -0,0 +1,124 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "context" + "fmt" + "net" + "sync/atomic" + "time" + + "github.com/acronis/go-appkit/testutil" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/reflection" + "google.golang.org/grpc/status" + + "github.com/acronis/go-authkit/idptoken/pb" +) + +type GRPCTokenCreator interface { + CreateToken(ctx context.Context, req *pb.CreateTokenRequest) (*pb.CreateTokenResponse, error) +} + +type GRPCTokenIntrospector interface { + IntrospectToken(ctx context.Context, req *pb.IntrospectTokenRequest) (*pb.IntrospectTokenResponse, error) +} + +type GRPCServer struct { + pb.UnimplementedIDPTokenServiceServer + *grpc.Server + addr atomic.Value + serverOpts []grpc.ServerOption + tokenIntrospector GRPCTokenIntrospector + tokenCreator GRPCTokenCreator +} + +type GRPCServerOption func(*GRPCServer) + +func WithGRPCAddr(addr string) GRPCServerOption { + return func(server *GRPCServer) { + server.addr.Store(addr) + } +} + +func WithGRPCServerOptions(opts ...grpc.ServerOption) GRPCServerOption { + return func(s *GRPCServer) { + s.serverOpts = opts + } +} + +func WithGRPCTokenIntrospector(tokenIntrospector GRPCTokenIntrospector) GRPCServerOption { + return func(s *GRPCServer) { + s.tokenIntrospector = tokenIntrospector + } +} + +func WithGRPCTokenCreator(tokenCreator GRPCTokenCreator) GRPCServerOption { + return func(s *GRPCServer) { + s.tokenCreator = tokenCreator + } +} + +// NewGRPCServer creates a new instance of GRPCServer. +func NewGRPCServer( + opts ...GRPCServerOption, +) *GRPCServer { + srv := &GRPCServer{} + for _, opt := range opts { + opt(srv) + } + srv.Server = grpc.NewServer(srv.serverOpts...) + pb.RegisterIDPTokenServiceServer(srv.Server, srv) + reflection.Register(srv.Server) + return srv +} + +// Addr returns the server address. +func (s *GRPCServer) Addr() string { + return s.addr.Load().(string) +} + +// Start starts the GRPC server +func (s *GRPCServer) Start() error { + addr, ok := s.addr.Load().(string) + if !ok { + addr = localhostWithDynamicPortAddr + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("listen tcp: %w", err) + } + s.addr.Store(ln.Addr().String()) + + go func() { _ = s.Serve(ln) }() + + return nil +} + +// StartAndWaitForReady starts the server waits for the server to start listening. +func (s *GRPCServer) StartAndWaitForReady(timeout time.Duration) error { + if err := s.Start(); err != nil { + return fmt.Errorf("start server: %w", err) + } + return testutil.WaitListeningServer(s.Addr(), timeout) +} + +func (s *GRPCServer) CreateToken(ctx context.Context, req *pb.CreateTokenRequest) (*pb.CreateTokenResponse, error) { + if s.tokenCreator != nil { + return s.tokenCreator.CreateToken(ctx, req) + } + return nil, status.Errorf(codes.Unimplemented, "method CreateToken not implemented") +} + +func (s *GRPCServer) IntrospectToken(ctx context.Context, req *pb.IntrospectTokenRequest) (*pb.IntrospectTokenResponse, error) { + if s.tokenIntrospector != nil { + return s.tokenIntrospector.IntrospectToken(ctx, req) + } + return nil, status.Errorf(codes.Unimplemented, "method IntrospectToken not implemented") +} diff --git a/idptest/http_server.go b/idptest/http_server.go new file mode 100644 index 0000000..6f71464 --- /dev/null +++ b/idptest/http_server.go @@ -0,0 +1,175 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "fmt" + "net" + "net/http" + "sync/atomic" + "time" + + "github.com/acronis/go-appkit/testutil" + + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/jwt" +) + +const ( + OpenIDConfigurationPath = "/.well-known/openid-configuration" + JWKSEndpointPath = "/idp/keys" + TokenEndpointPath = "/idp/token" + TokenIntrospectionEndpointPath = "/idp/introspect_token" // nolint:gosec // This server is used for testing purposes only. +) + +const localhostWithDynamicPortAddr = "127.0.0.1:0" + +// HTTPClaimsProvider is an interface for providing JWT claims in HTTP handlers. +type HTTPClaimsProvider interface { + Provide(r *http.Request) jwt.Claims +} + +// HTTPTokenIntrospector is an interface for introspecting tokens. +type HTTPTokenIntrospector interface { + IntrospectToken(r *http.Request, token string) idptoken.IntrospectionResult +} + +type HTTPServerOption func(s *HTTPServer) + +// WithHTTPAddress is an option to set HTTP server address. +func WithHTTPAddress(addr string) HTTPServerOption { + return func(s *HTTPServer) { + s.addr.Store(addr) + } +} + +// WithHTTPOpenIDConfigurationHandler is an option to set custom handler for GET /.well-known/openid-configuration. +// Otherwise, OpenIDConfigurationHandler will be used. +func WithHTTPOpenIDConfigurationHandler(handler http.HandlerFunc) HTTPServerOption { + return func(s *HTTPServer) { + s.OpenIDConfigurationHandler = handler + } +} + +// WithHTTPKeysHandler is an option to set custom handler for GET /idp/keys. +// Otherwise, JWKSHandler will be used. +func WithHTTPKeysHandler(handler http.Handler) HTTPServerOption { + return func(s *HTTPServer) { + s.KeysHandler = handler + } +} + +// WithHTTPPublicJWKS is an option to set public JWKS for JWKSHandler which will be used for GET /idp/keys. +func WithHTTPPublicJWKS(keys []PublicJWK) HTTPServerOption { + return func(s *HTTPServer) { + s.KeysHandler = &JWKSHandler{PublicJWKS: keys} + } +} + +// WithHTTPTokenHandler is an option to set custom handler for POST /idp/token. +func WithHTTPTokenHandler(handler http.Handler) HTTPServerOption { + return func(s *HTTPServer) { + s.TokenHandler = handler + } +} + +// WithHTTPClaimsProvider is an option to set ClaimsProvider for TokenHandler +// which will be used for POST /idp/token. +func WithHTTPClaimsProvider(claimsProvider HTTPClaimsProvider) HTTPServerOption { + return func(s *HTTPServer) { + s.TokenHandler = &TokenHandler{ClaimsProvider: claimsProvider} + } +} + +// WithHTTPIntrospectTokenHandler is an option to set custom handler for POST /idp/introspect_token. +func WithHTTPIntrospectTokenHandler(handler http.Handler) HTTPServerOption { + return func(s *HTTPServer) { + s.TokenIntrospectionHandler = handler + } +} + +// WithHTTPTokenIntrospector is an option to set TokenIntrospector for TokenIntrospectionHandler +// which will be used for POST /idp/introspect_token. +func WithHTTPTokenIntrospector(introspector HTTPTokenIntrospector) HTTPServerOption { + return func(s *HTTPServer) { + s.TokenIntrospectionHandler = &TokenIntrospectionHandler{TokenIntrospector: introspector} + } +} + +// HTTPServer is a mock IDP server for testing purposes. +type HTTPServer struct { + *http.Server + addr atomic.Value + KeysHandler http.Handler + TokenHandler http.Handler + TokenIntrospectionHandler http.Handler + OpenIDConfigurationHandler http.Handler + Router *http.ServeMux +} + +// NewHTTPServer creates a new IDPMockServer with provided options. +func NewHTTPServer(options ...HTTPServerOption) *HTTPServer { + s := &HTTPServer{ + Router: http.NewServeMux(), + TokenHandler: &TokenHandler{}, + KeysHandler: &JWKSHandler{}, + TokenIntrospectionHandler: &TokenIntrospectionHandler{}, + } + s.OpenIDConfigurationHandler = &OpenIDConfigurationHandler{ + BaseURLFunc: s.URL, + JWKSURL: JWKSEndpointPath, + TokenEndpointURL: TokenEndpointPath, + IntrospectionEndpointURL: TokenIntrospectionEndpointPath, + } + + for _, opt := range options { + opt(s) + } + + s.Router.Handle(OpenIDConfigurationPath, s.OpenIDConfigurationHandler) + s.Router.Handle(JWKSEndpointPath, s.KeysHandler) + s.Router.Handle(TokenEndpointPath, s.TokenHandler) + s.Router.Handle(TokenIntrospectionEndpointPath, s.TokenIntrospectionHandler) + + // nolint:gosec // This server is used for testing purposes only. + s.Server = &http.Server{Handler: s.Router} + + return s +} + +// URL method returns the URL of the server. +func (s *HTTPServer) URL() string { + if srvURL := s.addr.Load(); srvURL != nil { + return "http://" + srvURL.(string) + } + return "" +} + +// Start starts the HTTPServer. +func (s *HTTPServer) Start() error { + addr, ok := s.addr.Load().(string) + if !ok { + addr = localhostWithDynamicPortAddr + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("listen tcp: %w", err) + } + s.addr.Store(ln.Addr().String()) + + go func() { _ = s.Server.Serve(ln) }() + + return nil +} + +// StartAndWaitForReady starts the server waits for the server to start listening. +func (s *HTTPServer) StartAndWaitForReady(timeout time.Duration) error { + if err := s.Start(); err != nil { + return fmt.Errorf("start server: %w", err) + } + return testutil.WaitListeningServer(s.addr.Load().(string), timeout) +} diff --git a/idptest/jwks_handler.go b/idptest/jwks_handler.go new file mode 100644 index 0000000..32256b0 --- /dev/null +++ b/idptest/jwks_handler.go @@ -0,0 +1,100 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "encoding/json" + "fmt" + "net/http" + "sync/atomic" +) + +// TestKeyID is a key ID of the pre-defined key for testing. +const TestKeyID = "fac01c070cd08ba08809762da6e4f74af14e4790" + +// TestPlainPrivateJWK is a plaintext representation of the pre-defined private key for testing. +// nolint: lll +const TestPlainPrivateJWK = ` +{ + "alg": "RS256", + "d": "U4iZRcf35HT68wCF4cXPvOCU-aHbHoxkb99okPIf_pyexcznb3AjJCS9PRW3MnR1UkcDJ509Cwq7HTZ6dSPO8bMagGlFR4PttNgYzRg793r7ZamhzjPy_Udr35a79z6Q3rLBzeyFljhZ708cgU-tYw_7KpPytPN9cvr9MuvHtRWZlYuFRIql0PiOkq9hLMz_rLGb2rEmlQ9Bxk0tAOct9et-k8qgqwUjd0APgyHRxBU3gKUcbwgIb4KjYzypVjAW8Y3eIN8DwJg8P9AsMWdyKLW36exN_jZXGq6HrdecqV5hOGRELQ3Ok4sn-XMEyYVu9urQkZIRHGsbvUdwXskYKQ", + "dp": "jzn3HUi8J7QJF0JIT1VRbX8ngf4c7EDpV76rjTjdDNgGQJF6RZ34DfZoSWJlnoS2aJl2LW_Z2dXUnSi2JzVK_joEmFtMCe6vEiYKl3_2Avw43wfSJ1Kj7CTAOlRiifzdP9RosoYgznLjKq9_WyKlFVy7jXN11f-SBjiwYGyn_FM", + "dq": "pV6go3sp23q7jJP0DMrY0fmrC3AhWTGB3hb9w_KKVDV_J1ljSCpkSB55FfHkCsiTFfBHieCThNl7iWgl0eZJ7TtsdVUSj_dZfAMfi49nj6GBa5mjEUMSGgtUrqWNWf31rKKXz4Y4o0A6U8N8FuX34ANCj9hEBE0UIzdV_e7L1AE", + "e": "AQAB", + "kid": "fac01c070cd08ba08809762da6e4f74af14e4790", + "kty": "RSA", + "n": "mWeDDhcnVdKWbYGubOB7v1rZ395noYk-MFV0Ik78nLsJc1Ni3-GaWpJOTfCFivDP6DcS68Q04olx6_CleaDWU2KHeZE9PuJcW1_Xot3w1U2WZYpzl5_E5jqHjq1-nnOfe5Mq5SbpoZi3o3-QjktiSgaZ6w-575anM-6VhfxyS0s_DKGJHzyka1hJIoGb8vBstKS6oVLcgjQO3JR_Uy4XMdO9s3z-t3_4sO7qtHuEmqFUnaUx5MuLmZnV0hWyLHoNtEQZrf6X5lcnSj-6QerRihJdQeFDm494D96UwjKt70xgbAMvY-H2RcCJ5IqB2jvumqACt70twX7VCeS8FDMP_w", + "p": "w7rqemF-CmOU2X4p_4yzZaVq5CYmq9f-d1QLfK9AdMhIAPAlGxIkevXq6dAnjLWLJ9ksuOFkjWpoNI40JyhPJqif8U8WFyDqMsAEFif4HYVh-iR3NMr489lExBqx-YmmYHJ-pXxpcQhwAIbUkS-iF4eIx44JwVPNniU97Djy_ws", + "q": "yKQfjhWZSFzsn1CveQS6X6H1GtIbpWW9WBR0TFyUWrDtBxe1ivv21ie9hMDhpwk9t9ONUXqt-nNDMtK558q_fGKzMDwYIztX5vXRW9MMR6A7gylSGVspsUbk-egE2dXpwaGwdwr1RvFHEjBNeJQWxvQH-g-QNhJQm6gBdzn6210", + "qi": "Yt33e8KxCstCfgD4MvPg-uTVj6o2f893zbast8b_yunEBZK-c4WnJ73Taj7lOB2iME97XrBsx3f-jdslt6xHd9h0mam_Fi53JxQDoiyPcLWfcgcMY2w4jjoY_-Iqtnnisf7tHGgrba9eyNHRl91oXFgoaduNmeUs1z_yF_GARJo", + "use": "sig" +} +` + +type PublicJWK struct { + Alg string `json:"alg"` + E string `json:"e"` + Kid string `json:"kid"` + Kty string `json:"kty"` + N string `json:"n"` + Use string `json:"use"` +} + +func GetTestPublicJWKS() []PublicJWK { + return []PublicJWK{ + { + Alg: "RS256", + E: "AQAB", + Kid: TestKeyID, + Kty: "RSA", + N: "mWeDDhcnVdKWbYGubOB7v1rZ395noYk-MFV0Ik78nLsJc1Ni3-GaWpJOTfCFivDP6DcS68Q04olx6_CleaDWU2KHeZE9PuJcW1_Xot3w1U2WZYpzl5_E5jqHjq1-nnOfe5Mq5SbpoZi3o3-QjktiSgaZ6w-575anM-6VhfxyS0s_DKGJHzyka1hJIoGb8vBstKS6oVLcgjQO3JR_Uy4XMdO9s3z-t3_4sO7qtHuEmqFUnaUx5MuLmZnV0hWyLHoNtEQZrf6X5lcnSj-6QerRihJdQeFDm494D96UwjKt70xgbAMvY-H2RcCJ5IqB2jvumqACt70twX7VCeS8FDMP_w", // nolint:lll + Use: "sig", + }, + { + Alg: "RS256", + E: "AQAB", + Kid: "737c5114f09b5ed05276bd4b520245982f7fb29f", + Kty: "RSA", + N: "51gGypRFvhTziiCLW3emsFx80G3ljpoYdDdieYM-yfvv6cfpkiEnxRRig5JdJ62vrENgbZi1GZpvTs3B7ly7Z4FI6EM-5e8vIkQSYuE3sXU7QsxEFjtMUm31kao4179gmIIrycHl5M1HE2FU2Ssgf7VuKIVmLvDypNHgBb8cV2XKu_PiGHk2turbKZXxegJTiMBYrgKSaEuBUi3WC3j-onHmQriThchQujmXVMFQ-5syNkUX7hM8PKKONkFUhKANnh0Om8_Sc3bcYZAIoFA2cD-PXopJUQa8GLRfWLExVHRvp-4_vtDYbEAeipPYz2cRmEoMKiLRk8ZpLI6M71ugLQ", // nolint:lll + Use: "sig", + }, + } +} + +// JWKSHandler is an HTTP handler that responds JWKS. +type JWKSHandler struct { + servedCount atomic.Uint64 + PublicJWKS []PublicJWK +} + +func (h *JWKSHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(rw, "Only GET method is allowed", http.StatusMethodNotAllowed) + return + } + + h.servedCount.Add(1) + + rw.Header().Set("Content-Type", "application/json") + publicJWKS := h.PublicJWKS + if len(publicJWKS) == 0 { + publicJWKS = GetTestPublicJWKS() + } + if err := json.NewEncoder(rw).Encode(PublicJWKSResponse{Keys: publicJWKS}); err != nil { + http.Error(rw, fmt.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } +} + +// ServedCount returns the number of times JWKS handler has been served. +func (h *JWKSHandler) ServedCount() uint64 { + return h.servedCount.Load() +} + +type PublicJWKSResponse struct { + Keys []PublicJWK `json:"keys"` +} diff --git a/idptest/jwt.go b/idptest/jwt.go new file mode 100644 index 0000000..172dee5 --- /dev/null +++ b/idptest/jwt.go @@ -0,0 +1,86 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "crypto" + "encoding/json" + + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/mendsley/gojwk" + + "github.com/acronis/go-authkit/idptoken" +) + +// SignToken signs token with key. +func SignToken(token *jwtgo.Token, rsaPrivateKey interface{}) (string, error) { + return token.SignedString(rsaPrivateKey) +} + +// MakeTokenStringWithHeader create test signed token with claims and headers. +func MakeTokenStringWithHeader( + claims jwtgo.Claims, kid string, rsaPrivateKey interface{}, header map[string]interface{}, +) (string, error) { + token := jwtgo.NewWithClaims(jwtgo.SigningMethodRS256, claims) + token.Header["typ"] = idptoken.JWTTypeAccessToken + token.Header["kid"] = kid + for k, v := range header { + token.Header[k] = v + } + return SignToken(token, rsaPrivateKey) +} + +// MustMakeTokenStringWithHeader create test signed token with claims and headers. +// It panics if error occurs. +func MustMakeTokenStringWithHeader( + claims jwtgo.Claims, kid string, rsaPrivateKey interface{}, header map[string]interface{}, +) string { + token, err := MakeTokenStringWithHeader(claims, kid, rsaPrivateKey, header) + if err != nil { + panic(err) + } + return token +} + +// MakeTokenString create signed token with claims. +func MakeTokenString(claims jwtgo.Claims, kid string, rsaPrivateKey interface{}) (string, error) { + return MakeTokenStringWithHeader(claims, kid, rsaPrivateKey, nil) +} + +// MustMakeTokenString create signed token with claims. +// It panics if error occurs. +func MustMakeTokenString(claims jwtgo.Claims, kid string, rsaPrivateKey interface{}) string { + token, err := MakeTokenStringWithHeader(claims, kid, rsaPrivateKey, nil) + if err != nil { + panic(err) + } + return token +} + +// GetTestRSAPrivateKey returns pre-defined RSA private key for testing. +func GetTestRSAPrivateKey() crypto.PrivateKey { + var privKey gojwk.Key + _ = json.Unmarshal([]byte(TestPlainPrivateJWK), &privKey) + rsaPrivKey, _ := privKey.DecodePrivateKey() + return rsaPrivKey +} + +// MakeTokenStringSignedWithTestKey create test token signed with the pre-defined private key (TestKeyID) for testing. +func MakeTokenStringSignedWithTestKey(claims jwtgo.Claims) (string, error) { + return MakeTokenStringWithHeader(claims, TestKeyID, GetTestRSAPrivateKey(), nil) +} + +// MustMakeTokenStringSignedWithTestKey create test token signed +// with the pre-defined private key (TestKeyID) for testing. +// It panics if error occurs. +func MustMakeTokenStringSignedWithTestKey(claims jwtgo.Claims) string { + token, err := MakeTokenStringSignedWithTestKey(claims) + if err != nil { + panic(err) + } + return token +} diff --git a/idptest/jwt_test.go b/idptest/jwt_test.go new file mode 100644 index 0000000..a8056e9 --- /dev/null +++ b/idptest/jwt_test.go @@ -0,0 +1,54 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +const testIss = "test-issuer" + +func TestMakeTokenStringWithHeader(t *testing.T) { + jwksServer := httptest.NewServer(&JWKSHandler{}) + defer jwksServer.Close() + + issuerConfigServer := httptest.NewServer(&OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + logger := log.NewDisabledLogger() + + jwtClaims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{ + {ResourceNamespace: "policy_manager", Role: "admin"}, + }, + } + + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + parsedClaims, err := parser.Parse(context.Background(), MustMakeTokenStringSignedWithTestKey(jwtClaims)) + require.NoError(t, err) + require.Equal( + t, + []jwt.AccessPolicy{{ResourceNamespace: "policy_manager", Role: "admin"}}, + parsedClaims.Scope, + ) +} diff --git a/idptest/openid_configuration_handler.go b/idptest/openid_configuration_handler.go new file mode 100644 index 0000000..a790bc2 --- /dev/null +++ b/idptest/openid_configuration_handler.go @@ -0,0 +1,60 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "encoding/json" + "fmt" + "net/http" + "sync/atomic" + + "github.com/acronis/go-authkit/internal/idputil" +) + +// OpenIDConfigurationHandler is an HTTP handler that responds token's issuer OpenID configuration. +type OpenIDConfigurationHandler struct { + servedCount atomic.Uint64 + BaseURLFunc func() string // for cases when 'host:port' of providers' addresses to be determined during runtime + JWKSURL string + TokenEndpointURL string + IntrospectionEndpointURL string +} + +func (h *OpenIDConfigurationHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(rw, "Only GET method is allowed", http.StatusMethodNotAllowed) + return + } + + h.servedCount.Add(1) + + openIDCfg := idputil.OpenIDConfiguration{ + TokenURL: h.makeEndpointURL(h.TokenEndpointURL, TokenIntrospectionEndpointPath), + IntrospectionEndpoint: h.makeEndpointURL(h.IntrospectionEndpointURL, TokenIntrospectionEndpointPath), + JWKSURI: h.makeEndpointURL(h.JWKSURL, JWKSEndpointPath), + } + rw.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(rw).Encode(openIDCfg); err != nil { + http.Error(rw, fmt.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } +} + +func (h *OpenIDConfigurationHandler) makeEndpointURL(endpointURL string, defaultPath string) string { + if endpointURL == "" { + endpointURL = defaultPath + } + if h.BaseURLFunc != nil { + endpointURL = h.BaseURLFunc() + endpointURL + } + return endpointURL +} + +// ServedCount returns the number of times the handler has been served. +func (h *OpenIDConfigurationHandler) ServedCount() uint64 { + return h.servedCount.Load() +} diff --git a/idptest/token_handlers.go b/idptest/token_handlers.go new file mode 100644 index 0000000..a260aa6 --- /dev/null +++ b/idptest/token_handlers.go @@ -0,0 +1,103 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "encoding/json" + "fmt" + "net/http" + "sync/atomic" + "time" +) + +// TokenHandler is an implementation of a handler responding with IDP token. +type TokenHandler struct { + servedCount atomic.Uint64 + ClaimsProvider HTTPClaimsProvider +} + +func (h *TokenHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(rw, "Only POST method is allowed", http.StatusMethodNotAllowed) + return + } + + h.servedCount.Add(1) + + if h.ClaimsProvider == nil { + http.Error(rw, "ClaimsProvider for TokenHandler is not configured", http.StatusInternalServerError) + return + } + + claims := h.ClaimsProvider.Provide(r) + + token, err := MakeTokenStringWithHeader(claims, TestKeyID, GetTestRSAPrivateKey(), nil) + if err != nil { + http.Error(rw, fmt.Sprintf("Claims provider failed generate token: %v", err), http.StatusInternalServerError) + return + } + + expiresIn := claims.ExpiresAt.Unix() - time.Now().UTC().Unix() + if expiresIn < 0 { + expiresIn = 0 + } + + response := struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + }{ + AccessToken: token, + ExpiresIn: expiresIn, + } + rw.Header().Set("Content-Type", "application/json") + if err = json.NewEncoder(rw).Encode(response); err != nil { + http.Error(rw, fmt.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } +} + +// ServedCount returns the number of times the handler has been served. +func (h *TokenHandler) ServedCount() uint64 { + return h.servedCount.Load() +} + +type TokenIntrospectionHandler struct { + servedCount atomic.Uint64 + TokenIntrospector HTTPTokenIntrospector +} + +func (h *TokenIntrospectionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(rw, "Only POST method is allowed", http.StatusMethodNotAllowed) + return + } + + h.servedCount.Add(1) + + if h.TokenIntrospector == nil { + http.Error(rw, "HTTPTokenIntrospector for TokenIntrospectionHandler is not configured", http.StatusInternalServerError) + return + } + + token := r.FormValue("token") + if token == "" { + http.Error(rw, "Token is required", http.StatusBadRequest) + return + } + introspectResult := h.TokenIntrospector.IntrospectToken(r, token) + + rw.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(rw).Encode(introspectResult); err != nil { + http.Error(rw, fmt.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } +} + +// ServedCount returns the number of times the handler has been served. +func (h *TokenIntrospectionHandler) ServedCount() uint64 { + return h.servedCount.Load() +} diff --git a/idptest/token_provider.go b/idptest/token_provider.go new file mode 100644 index 0000000..1a3a386 --- /dev/null +++ b/idptest/token_provider.go @@ -0,0 +1,26 @@ +package idptest + +import ( + "context" + "sync/atomic" +) + +type SimpleTokenProvider struct { + token atomic.Value +} + +func NewSimpleTokenProvider(token string) *SimpleTokenProvider { + tp := &SimpleTokenProvider{} + tp.SetToken(token) + return tp +} + +func (m *SimpleTokenProvider) GetToken(ctx context.Context, scope ...string) (string, error) { + return m.token.Load().(string), nil +} + +func (m *SimpleTokenProvider) Invalidate() {} + +func (m *SimpleTokenProvider) SetToken(token string) { + m.token.Store(token) +} diff --git a/idptoken/caching_introspector.go b/idptoken/caching_introspector.go new file mode 100644 index 0000000..e00a807 --- /dev/null +++ b/idptoken/caching_introspector.go @@ -0,0 +1,194 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken + +import ( + "context" + "crypto/sha256" + "fmt" + "time" + "unsafe" + + "github.com/acronis/go-appkit/lrucache" + + "github.com/acronis/go-authkit/jwt" +) + +const ( + DefaultIntrospectionClaimsCacheMaxEntries = 1000 + DefaultIntrospectionClaimsCacheTTL = 1 * time.Minute + DefaultIntrospectionNegativeCacheMaxEntries = 1000 + DefaultIntrospectionNegativeCacheTTL = 10 * time.Minute +) + +type IntrospectionClaimsCacheItem struct { + Claims *jwt.Claims + TokenType string + CreatedAt time.Time +} + +type IntrospectionClaimsCache interface { + Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionClaimsCacheItem, bool) + Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionClaimsCacheItem) + Purge(ctx context.Context) + Len(ctx context.Context) int +} + +type IntrospectionNegativeCacheItem struct { + CreatedAt time.Time +} + +type IntrospectionNegativeCache interface { + Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionNegativeCacheItem, bool) + Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionNegativeCacheItem) + Purge(ctx context.Context) + Len(ctx context.Context) int +} + +type CachingIntrospectorOpts struct { + IntrospectorOpts + ClaimsCache CachingIntrospectorCacheOpts + NegativeCache CachingIntrospectorCacheOpts +} + +type CachingIntrospectorCacheOpts struct { + Enabled bool + MaxEntries int + TTL time.Duration +} + +type CachingIntrospector struct { + *Introspector + ClaimsCache IntrospectionClaimsCache + NegativeCache IntrospectionNegativeCache + claimsCacheTTL time.Duration + negativeCacheTTL time.Duration +} + +func NewCachingIntrospector(tokenProvider IntrospectionTokenProvider) (*CachingIntrospector, error) { + return NewCachingIntrospectorWithOpts(tokenProvider, CachingIntrospectorOpts{}) +} + +func NewCachingIntrospectorWithOpts( + tokenProvider IntrospectionTokenProvider, opts CachingIntrospectorOpts, +) (*CachingIntrospector, error) { + if !opts.ClaimsCache.Enabled && !opts.NegativeCache.Enabled { + return nil, fmt.Errorf("at least one of claims or negative cache must be enabled") + } + + introspector := NewIntrospectorWithOpts(tokenProvider, opts.IntrospectorOpts) + + // Building claims cache if needed. + var claimsCache IntrospectionClaimsCache = &disabledIntrospectionClaimsCache{} + if opts.ClaimsCache.Enabled { + if opts.ClaimsCache.TTL == 0 { + opts.ClaimsCache.TTL = DefaultIntrospectionClaimsCacheTTL + } + if opts.ClaimsCache.MaxEntries == 0 { + opts.ClaimsCache.MaxEntries = DefaultIntrospectionClaimsCacheMaxEntries + } + cache, err := lrucache.New[[sha256.Size]byte, IntrospectionClaimsCacheItem]( + opts.ClaimsCache.MaxEntries, introspector.promMetrics.TokenClaimsCache) + if err != nil { + return nil, err + } + claimsCache = &introspectionCacheLRUAdapter[[sha256.Size]byte, IntrospectionClaimsCacheItem]{cache} + } + + // Building negative cache if needed. + var negativeCache IntrospectionNegativeCache = &disabledIntrospectionNegativeCache{} + if opts.NegativeCache.Enabled { + if opts.NegativeCache.TTL == 0 { + opts.NegativeCache.TTL = DefaultIntrospectionNegativeCacheTTL + } + if opts.NegativeCache.MaxEntries == 0 { + opts.NegativeCache.MaxEntries = DefaultIntrospectionNegativeCacheMaxEntries + } + cache, err := lrucache.New[[sha256.Size]byte, IntrospectionNegativeCacheItem]( + opts.NegativeCache.MaxEntries, introspector.promMetrics.TokenNegativeCache) + if err != nil { + return nil, err + } + negativeCache = &introspectionCacheLRUAdapter[[sha256.Size]byte, IntrospectionNegativeCacheItem]{cache} + } + + return &CachingIntrospector{ + Introspector: introspector, + ClaimsCache: claimsCache, + NegativeCache: negativeCache, + claimsCacheTTL: opts.ClaimsCache.TTL, + negativeCacheTTL: opts.NegativeCache.TTL, + }, nil +} + +func (i *CachingIntrospector) IntrospectToken(ctx context.Context, token string) (IntrospectionResult, error) { + cacheKey := sha256.Sum256( + unsafe.Slice(unsafe.StringData(token), len(token))) // nolint:gosec // prevent redundant slice copying + + if c, ok := i.ClaimsCache.Get(ctx, cacheKey); ok && c.CreatedAt.Add(i.claimsCacheTTL).After(time.Now()) { + return IntrospectionResult{Active: true, TokenType: c.TokenType, Claims: *c.Claims}, nil + } + if c, ok := i.NegativeCache.Get(ctx, cacheKey); ok && c.CreatedAt.Add(i.negativeCacheTTL).After(time.Now()) { + return IntrospectionResult{Active: false}, nil + } + + introspectionResult, err := i.Introspector.IntrospectToken(ctx, token) + if err != nil { + return IntrospectionResult{}, err + } + if introspectionResult.Active { + i.ClaimsCache.Add(ctx, cacheKey, IntrospectionClaimsCacheItem{ + Claims: &introspectionResult.Claims, + TokenType: introspectionResult.TokenType, + CreatedAt: time.Now(), + }) + } else { + i.NegativeCache.Add(ctx, cacheKey, IntrospectionNegativeCacheItem{CreatedAt: time.Now()}) + } + + return introspectionResult, nil +} + +type introspectionCacheLRUAdapter[K comparable, V any] struct { + cache *lrucache.LRUCache[K, V] +} + +func (a *introspectionCacheLRUAdapter[K, V]) Get(_ context.Context, key K) (V, bool) { + return a.cache.Get(key) +} + +func (a *introspectionCacheLRUAdapter[K, V]) Add(_ context.Context, key K, val V) { + a.cache.Add(key, val) +} + +func (a *introspectionCacheLRUAdapter[K, V]) Purge(ctx context.Context) { + a.cache.Purge() +} + +func (a *introspectionCacheLRUAdapter[K, V]) Len(ctx context.Context) int { + return a.cache.Len() +} + +type disabledIntrospectionClaimsCache struct{} + +func (c *disabledIntrospectionClaimsCache) Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionClaimsCacheItem, bool) { + return IntrospectionClaimsCacheItem{}, false +} +func (c *disabledIntrospectionClaimsCache) Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionClaimsCacheItem) { +} +func (c *disabledIntrospectionClaimsCache) Purge(ctx context.Context) {} +func (c *disabledIntrospectionClaimsCache) Len(ctx context.Context) int { return 0 } + +type disabledIntrospectionNegativeCache struct{} + +func (c *disabledIntrospectionNegativeCache) Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionNegativeCacheItem, bool) { + return IntrospectionNegativeCacheItem{}, false +} +func (c *disabledIntrospectionNegativeCache) Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionNegativeCacheItem) { +} +func (c *disabledIntrospectionNegativeCache) Purge(ctx context.Context) {} +func (c *disabledIntrospectionNegativeCache) Len(ctx context.Context) int { return 0 } diff --git a/idptoken/caching_introspector_test.go b/idptoken/caching_introspector_test.go new file mode 100644 index 0000000..b6575f7 --- /dev/null +++ b/idptoken/caching_introspector_test.go @@ -0,0 +1,255 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken_test + +import ( + "context" + "net/http" + "net/url" + gotesting "testing" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/internal/testing" + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +func TestCachingIntrospector_IntrospectToken(t *gotesting.T) { + serverIntrospector := testing.NewHTTPServerTokenIntrospectorMock() + + idpSrv := idptest.NewHTTPServer(idptest.WithHTTPTokenIntrospector(serverIntrospector)) + require.NoError(t, idpSrv.StartAndWaitForReady(time.Second)) + defer func() { _ = idpSrv.Shutdown(context.Background()) }() + + const accessToken = "access-token-with-introspection-permission" + tokenProvider := idptest.NewSimpleTokenProvider(accessToken) + + logger := log.NewDisabledLogger() + jwtParser := jwt.NewParser(jwks.NewClient(http.DefaultClient, logger), logger) + require.NoError(t, jwtParser.AddTrustedIssuerURL(idpSrv.URL())) + serverIntrospector.JWTParser = jwtParser + + jwtExpiresAtInFuture := jwtgo.NewNumericDate(time.Now().Add(time.Hour)) + jwtIssuer := idpSrv.URL() + jwtSubject := uuid.NewString() + jwtID := uuid.NewString() + jwtScope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "account_viewer", + ResourcePath: "resource-" + uuid.NewString(), + }} + + expiredJWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: idpSrv.URL(), + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(-time.Hour)), + }, + }) + activeJWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: jwtIssuer, + Subject: jwtSubject, + ID: jwtID, + ExpiresAt: jwtExpiresAtInFuture, + }, + }) + + opaqueToken1 := "opaque-token-" + uuid.NewString() + opaqueToken2 := "opaque-token-" + uuid.NewString() + opaqueToken3 := "opaque-token-" + uuid.NewString() + opaqueToken1Scope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "admin", + ResourcePath: "resource-" + uuid.NewString(), + }} + opaqueToken2Scope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "event-manager", + Role: "admin", + ResourcePath: "resource-" + uuid.NewString(), + }} + + serverIntrospector.SetScopeForJWTID(jwtID, jwtScope) + serverIntrospector.SetResultForToken(opaqueToken1, idptoken.IntrospectionResult{ + Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}) + serverIntrospector.SetResultForToken(opaqueToken2, idptoken.IntrospectionResult{ + Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken2Scope}}) + serverIntrospector.SetResultForToken(opaqueToken3, idptoken.IntrospectionResult{Active: false}) + + tests := []struct { + name string + introspectorOpts idptoken.CachingIntrospectorOpts + tokens []string + expectedSrvCalled []bool + expectedResult []idptoken.IntrospectionResult + checkError []func(t *gotesting.T, err error) + checkIntrospector func(t *gotesting.T, introspector *idptoken.CachingIntrospector) + delay time.Duration + }{ + { + name: "error, token is not introspectable", + tokens: []string{"", "opaque-token"}, + expectedSrvCalled: []bool{false, false}, + introspectorOpts: idptoken.CachingIntrospectorOpts{ + ClaimsCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + NegativeCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + }, + checkError: []func(t *gotesting.T, err error){ + func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, "token is missing") + }, + func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, "no JWT header found") + }, + }, + checkIntrospector: func(t *gotesting.T, introspector *idptoken.CachingIntrospector) { + require.Equal(t, 0, introspector.ClaimsCache.Len(context.Background())) + require.Equal(t, 0, introspector.NegativeCache.Len(context.Background())) + }, + }, + { + name: "ok, dynamic introspection endpoint, introspected token is expired JWT", + introspectorOpts: idptoken.CachingIntrospectorOpts{ + ClaimsCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + NegativeCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + }, + tokens: repeat(expiredJWT, 2), + expectedSrvCalled: []bool{true, false}, + expectedResult: []idptoken.IntrospectionResult{{Active: false}, {Active: false}}, + checkIntrospector: func(t *gotesting.T, introspector *idptoken.CachingIntrospector) { + require.Equal(t, 0, introspector.ClaimsCache.Len(context.Background())) + require.Equal(t, 1, introspector.NegativeCache.Len(context.Background())) + }, + }, + { + name: "ok, dynamic introspection endpoint, introspected token is JWT", + introspectorOpts: idptoken.CachingIntrospectorOpts{ + ClaimsCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + NegativeCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + }, + tokens: repeat(activeJWT, 2), + expectedSrvCalled: []bool{true, false}, + expectedResult: repeat(idptoken.IntrospectionResult{ + Active: true, + TokenType: idptoken.TokenTypeBearer, + Claims: jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: jwtIssuer, + Subject: jwtSubject, + ID: jwtID, + ExpiresAt: jwtExpiresAtInFuture, + }, + Scope: jwtScope, + }, + }, 2), + checkIntrospector: func(t *gotesting.T, introspector *idptoken.CachingIntrospector) { + require.Equal(t, 1, introspector.ClaimsCache.Len(context.Background())) + require.Equal(t, 0, introspector.NegativeCache.Len(context.Background())) + }, + }, + { + name: "ok, static introspection endpoint, introspected token is opaque", + introspectorOpts: idptoken.CachingIntrospectorOpts{ + IntrospectorOpts: idptoken.IntrospectorOpts{ + StaticHTTPEndpoint: idpSrv.URL() + idptest.TokenIntrospectionEndpointPath, + }, + ClaimsCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + NegativeCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + }, + tokens: []string{opaqueToken1, opaqueToken1, opaqueToken2, opaqueToken2, opaqueToken3, opaqueToken3}, + expectedSrvCalled: []bool{true, false, true, false, true, false}, + expectedResult: []idptoken.IntrospectionResult{ + {Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, + {Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, + {Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken2Scope}}, + {Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken2Scope}}, + {Active: false}, + {Active: false}, + }, + checkIntrospector: func(t *gotesting.T, introspector *idptoken.CachingIntrospector) { + require.Equal(t, 2, introspector.ClaimsCache.Len(context.Background())) + require.Equal(t, 1, introspector.NegativeCache.Len(context.Background())) + }, + }, + { + name: "ok, cache has ttl", + introspectorOpts: idptoken.CachingIntrospectorOpts{ + IntrospectorOpts: idptoken.IntrospectorOpts{ + StaticHTTPEndpoint: idpSrv.URL() + idptest.TokenIntrospectionEndpointPath, + }, + ClaimsCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true, TTL: 100 * time.Millisecond}, + NegativeCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true, TTL: 100 * time.Millisecond}, + }, + tokens: []string{opaqueToken1, opaqueToken1, opaqueToken3, opaqueToken3}, + expectedSrvCalled: []bool{true, true, true, true}, + expectedResult: []idptoken.IntrospectionResult{ + {Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, + {Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, + {Active: false}, + {Active: false}, + }, + checkIntrospector: func(t *gotesting.T, introspector *idptoken.CachingIntrospector) { + require.Equal(t, 1, introspector.ClaimsCache.Len(context.Background())) + require.Equal(t, 1, introspector.NegativeCache.Len(context.Background())) + }, + delay: 200 * time.Millisecond, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *gotesting.T) { + introspector, err := idptoken.NewCachingIntrospectorWithOpts(tokenProvider, tt.introspectorOpts) + require.NoError(t, err) + require.NoError(t, introspector.AddTrustedIssuerURL(idpSrv.URL())) + + for i, token := range tt.tokens { + serverIntrospector.ResetCallsInfo() + + result, introspectErr := introspector.IntrospectToken(context.Background(), token) + if i < len(tt.checkError) { + tt.checkError[i](t, introspectErr) + } else { + require.NoError(t, introspectErr) + require.Equal(t, tt.expectedResult[i], result) + } + + require.Equal(t, tt.expectedSrvCalled[i], serverIntrospector.Called) + if tt.expectedSrvCalled[i] { + require.Equal(t, token, serverIntrospector.LastIntrospectedToken) + require.Equal(t, "Bearer "+accessToken, serverIntrospector.LastAuthorizationHeader) + require.Equal(t, url.Values{"token": {token}}, serverIntrospector.LastFormValues) + } + + time.Sleep(tt.delay) + } + + if tt.checkIntrospector != nil { + tt.checkIntrospector(t, introspector) + } + }) + } +} + +func repeat[V any](v V, n int) []V { + s := make([]V, n) + for i := range s { + s[i] = v + } + return s +} diff --git a/idptoken/config.go b/idptoken/config.go new file mode 100644 index 0000000..92a3349 --- /dev/null +++ b/idptoken/config.go @@ -0,0 +1,61 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken + +import ( + "fmt" + + "github.com/acronis/go-appkit/config" +) + +const ( + cfgKeyIDPURL = "idp.url" + cfgKeyIDPClientID = "idp.clientId" + cfgKeyIDPClientSecret = "idp.clientSecret" +) + +// Config is a configuration for IDP token source. +type Config struct { + URL string + ClientID string + ClientSecret string +} + +var _ config.Config = (*Config)(nil) + +// NewConfig creates a new configuration for IDP token source. +func NewConfig() *Config { + return &Config{} +} + +// SetProviderDefaults sets the default values for the configuration. +func (c *Config) SetProviderDefaults(_ config.DataProvider) { +} + +// Set sets the configuration from the given data provider. +func (c *Config) Set(dp config.DataProvider) (err error) { + if c.URL, err = dp.GetString(cfgKeyIDPURL); err != nil { + return err + } + if c.URL == "" { + return dp.WrapKeyErr(cfgKeyIDPURL, fmt.Errorf("IDP URL is required")) + } + if c.ClientID, err = dp.GetString(cfgKeyIDPClientID); err != nil { + return err + } + if c.ClientID == "" { + return dp.WrapKeyErr(cfgKeyIDPClientID, fmt.Errorf("IDP client ID is required")) + } + if c.ClientSecret, err = dp.GetString(cfgKeyIDPClientSecret); err != nil { + return err + } + if c.ClientSecret == "" { + return dp.WrapKeyErr(cfgKeyIDPClientSecret, fmt.Errorf("IDP client secret is required")) + } + + return nil +} diff --git a/idptoken/config_test.go b/idptoken/config_test.go new file mode 100644 index 0000000..6736ede --- /dev/null +++ b/idptoken/config_test.go @@ -0,0 +1,105 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken_test + +import ( + "bytes" + "os" + "testing" + + "github.com/acronis/go-appkit/config" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptoken" +) + +func TestConfig(t *testing.T) { + type testCase struct { + name string + cfgData string + expectErr bool + errMsg string + setupEnv map[string]string // Environment variables to set + } + + testCases := []testCase{ + { + name: "valid config", + cfgData: ` +idp: + url: https://idp.example.com + clientId: client-id + clientSecret: client-secret +`, + expectErr: false, + }, + { + name: "missing url", + cfgData: ` +idp: + clientId: client-id + clientSecret: client-secret +`, + expectErr: true, + errMsg: `idp.url: IDP URL is required`, + }, + { + name: "missing client ID", + cfgData: ` +idp: + url: https://idp.example.com + clientSecret: client-secret +`, + expectErr: true, + errMsg: `idp.clientId: IDP client ID is required`, + }, + { + name: "missing client secret", + cfgData: ` +idp: + url: https://idp.example.com + clientId: client-id +`, + expectErr: true, + errMsg: `idp.clientSecret: IDP client secret is required`, + }, + { + name: "valid config from Env", + cfgData: ` +idp: + url: https://idp.example.com +`, + setupEnv: map[string]string{ + "IDP_CLIENTID": "client-id", + "IDP_CLIENTSECRET": "client-secret", + }, + expectErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup environment variables if needed + for k, v := range tc.setupEnv { + err := os.Setenv(k, v) + require.NoError(t, err) + defer func(k string) { + require.NoError(t, os.Unsetenv(k)) + }(k) + } + + cfg := idptoken.NewConfig() + err := config.NewDefaultLoader("").LoadFromReader(bytes.NewBufferString(tc.cfgData), config.DataTypeYAML, cfg) + + if tc.expectErr { + require.EqualError(t, err, tc.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/idptoken/doc.go b/idptoken/doc.go new file mode 100644 index 0000000..11bea92 --- /dev/null +++ b/idptoken/doc.go @@ -0,0 +1,10 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package idptoken provides a robust way to request access tokens from IDP. +// Provider is to be used for a single token source. +// MultiSourceProvider to be used for multiple token sources. +package idptoken diff --git a/idptoken/grpc_client.go b/idptoken/grpc_client.go new file mode 100644 index 0000000..c545cea --- /dev/null +++ b/idptoken/grpc_client.go @@ -0,0 +1,182 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "google.golang.org/grpc" + grpccodes "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" + grpcstatus "google.golang.org/grpc/status" + + "github.com/acronis/go-authkit/idptoken/pb" + "github.com/acronis/go-authkit/internal/metrics" + "github.com/acronis/go-authkit/jwt" +) + +// GRPCClientOpts contains options for the GRPCClient. +type GRPCClientOpts struct { + // Logger is a logger for the client. + Logger log.FieldLogger + + // RequestTimeout is a timeout for the gRPC requests. + RequestTimeout time.Duration + + // PrometheusLibInstanceLabel is a label for Prometheus metrics. + // It allows distinguishing metrics from different instances of the same library. + PrometheusLibInstanceLabel string +} + +// GRPCClient is a client for the IDP token service that uses gRPC. +type GRPCClient struct { + client pb.IDPTokenServiceClient + clientConn *grpc.ClientConn + reqTimeout time.Duration + promMetrics *metrics.PrometheusMetrics +} + +const grpcMetaAuthorization = "authorization" + +// NewGRPCClient creates a new GRPCClient instance that communicates with the IDP token service. +func NewGRPCClient( + target string, transportCreds credentials.TransportCredentials, +) (*GRPCClient, error) { + return NewGRPCClientWithOpts(target, transportCreds, GRPCClientOpts{}) +} + +// NewGRPCClientWithOpts creates a new GRPCClient instance that communicates with the IDP token service +// with the specified options. +func NewGRPCClientWithOpts( + target string, transportCreds credentials.TransportCredentials, opts GRPCClientOpts, +) (*GRPCClient, error) { + if opts.Logger == nil { + opts.Logger = log.NewDisabledLogger() + } + if opts.RequestTimeout == 0 { + opts.RequestTimeout = time.Second * 30 + } + dialCtx := context.Background() // context.Background() is ok since we don't use grpc.WithBlock() + conn, err := grpc.DialContext(dialCtx, target, + grpc.WithTransportCredentials(transportCreds), + grpc.WithStatsHandler(&statsHandler{logger: opts.Logger}), + grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), + ) + if err != nil { + return nil, fmt.Errorf("dial to %q: %w", target, err) + } + return &GRPCClient{ + client: pb.NewIDPTokenServiceClient(conn), + clientConn: conn, + reqTimeout: opts.RequestTimeout, + promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "grpc_client"), + }, nil +} + +// Close closes the client gRPC connection. +func (c *GRPCClient) Close() error { + return c.clientConn.Close() +} + +// IntrospectToken introspects the token using the IDP token service. +func (c *GRPCClient) IntrospectToken( + ctx context.Context, token string, scopeFilter []IntrospectionScopeFilterAccessPolicy, accessToken string, +) (IntrospectionResult, error) { + req := pb.IntrospectTokenRequest{ + Token: token, + ScopeFilter: make([]*pb.IntrospectionScopeFilter, len(scopeFilter)), + } + for i := range scopeFilter { + req.ScopeFilter[i] = &pb.IntrospectionScopeFilter{ResourceNamespace: scopeFilter[i].ResourceNamespace} + } + + ctx = metadata.AppendToOutgoingContext(ctx, grpcMetaAuthorization, makeBearerToken(accessToken)) + + ctx, ctxCancel := context.WithTimeout(ctx, c.reqTimeout) + defer ctxCancel() + + const methodName = "IDPTokenService/IntrospectToken" + startTime := time.Now() + resp, err := c.client.IntrospectToken(ctx, &req) + elapsed := time.Since(startTime) + if err != nil { + var code grpccodes.Code + if st, ok := grpcstatus.FromError(err); ok { + code = st.Code() + } + c.promMetrics.ObserveGRPCClientRequest(methodName, code, elapsed) + if code == grpccodes.Unauthenticated { + return IntrospectionResult{}, ErrTokenIntrospectionUnauthenticated + } + return IntrospectionResult{}, fmt.Errorf("introspect token: %w", err) + } + c.promMetrics.ObserveGRPCClientRequest(methodName, grpccodes.OK, elapsed) + + res := IntrospectionResult{ + Active: resp.GetActive(), + TokenType: resp.GetTokenType(), + Claims: jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: resp.GetIss(), + Subject: resp.GetSub(), + Audience: resp.GetAud(), + ID: resp.GetJti(), + }, + SubType: resp.GetSubType(), + ClientID: resp.GetClientId(), + OwnerTenantUUID: resp.GetOwnerTenantUuid(), + Scope: make([]jwt.AccessPolicy, len(resp.GetScope())), + }, + } + if resp.GetExp() != 0 { + res.Claims.ExpiresAt = jwtgo.NewNumericDate(time.Unix(resp.GetExp(), 0)) + } + for i, s := range resp.GetScope() { + res.Claims.Scope[i] = jwt.AccessPolicy{ + ResourceNamespace: s.GetResourceNamespace(), + Role: s.GetRoleName(), + ResourceServerID: s.GetResourceServer(), + ResourcePath: s.GetResourcePath(), + TenantUUID: s.GetTenantUuid(), + } + if s.GetTenantIntId() != 0 { + res.Claims.Scope[i].TenantID = strconv.FormatInt(s.GetTenantIntId(), 10) + } + } + return res, nil +} + +type statsHandler struct { + logger log.FieldLogger +} + +func (sh *statsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { + return ctx +} + +func (sh *statsHandler) HandleRPC(ctx context.Context, s stats.RPCStats) { +} + +func (sh *statsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context { + return ctx +} + +func (sh *statsHandler) HandleConn(ctx context.Context, s stats.ConnStats) { + switch s.(type) { + case *stats.ConnBegin: + sh.logger.Infof("grpc connection established") + case *stats.ConnEnd: + sh.logger.Infof("grpc connection closed") + } +} diff --git a/idptoken/idp_token.proto b/idptoken/idp_token.proto new file mode 100644 index 0000000..4299e6a --- /dev/null +++ b/idptoken/idp_token.proto @@ -0,0 +1,74 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +syntax = "proto3"; + +package idp_token; + +option go_package = "./pb"; + +service IDPTokenService { + // CreateToken creates a new token based on the provided assertion. + // Currently only "urn:ietf:params:oauth:grant-type:jwt-bearer" grant type is supported. + rpc CreateToken (CreateTokenRequest) returns (CreateTokenResponse); + + // IntrospectToken returns information about the token including its scopes. + // The token is considered active if + // 1) it's not expired; + // 2) it's not revoked; + // 3) it has has the valid signature. + rpc IntrospectToken (IntrospectTokenRequest) returns (IntrospectTokenResponse); +} + +message CreateTokenRequest { + reserved 4 to 50; + string grant_type = 1; // example: urn:ietf:params:oauth:grant-type:jwt-bearer + string assertion = 2; + uint32 token_version = 3; +} + +message CreateTokenResponse { + reserved 4 to 50; + string access_token = 1; + string token_type = 2; + int64 expires_in = 3; +} + +message IntrospectionScopeFilter { + reserved 2 to 50; + string ResourceNamespace = 1; +} + +message IntrospectTokenRequest { + reserved 3 to 50; + string token = 1; + repeated IntrospectionScopeFilter scope_filter = 2; +} + +message AccessTokenScope { + reserved 7 to 50; + string tenant_uuid = 1; + int64 tenant_int_id = 2; + string ResourceServer = 3; + string ResourceNamespace = 4; + string ResourcePath = 5; + string RoleName = 6; +} + +message IntrospectTokenResponse { + reserved 12 to 100; + bool active = 1; + string token_type = 2; + int64 exp = 3; + repeated string aud = 4; + string jti = 5; + string iss = 6; + string sub = 7; + string sub_type = 8; + string client_id = 9; + string owner_tenant_uuid = 10; // API client owner tenant UUID. + repeated AccessTokenScope scope = 11; +} \ No newline at end of file diff --git a/idptoken/introspector.go b/idptoken/introspector.go new file mode 100644 index 0000000..0475e01 --- /dev/null +++ b/idptoken/introspector.go @@ -0,0 +1,402 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + + "github.com/acronis/go-authkit/internal/idputil" + "github.com/acronis/go-authkit/internal/metrics" + "github.com/acronis/go-authkit/jwt" +) + +const DefaultRequestTimeout = 30 * time.Second + +const JWTTypeAccessToken = "at+jwt" + +const TokenTypeBearer = "bearer" + +const MinJWTVersionForIntrospection = 0 + +const minAccessTokenProviderInvalidationInterval = time.Minute + +const tokenIntrospectorPromSource = "token_introspector" + +// ErrTokenNotIntrospectable is returned when token is not introspectable. +var ErrTokenNotIntrospectable = errors.New("token is not introspectable") + +// ErrTokenIntrospectionNotNeeded is returned when token introspection is unnecessary +// (i.e., it already contains all necessary information). +var ErrTokenIntrospectionNotNeeded = errors.New("token introspection is not needed") + +// ErrTokenIntrospectionUnauthenticated is returned when token introspection is unauthenticated. +var ErrTokenIntrospectionUnauthenticated = errors.New("token introspection is unauthenticated") + +// TrustedIssNotFoundFallback is a function called when given issuer is not found in the list of trusted ones. +// For example, it could be analyzed and then added to the list by calling AddTrustedIssuerURL method. +type TrustedIssNotFoundFallback func(ctx context.Context, i *Introspector, iss string) (issURL string, issFound bool) + +// IntrospectionTokenProvider is an interface for getting access token for doing introspection. +// The token should have introspection permission. +type IntrospectionTokenProvider interface { + GetToken(ctx context.Context, scope ...string) (string, error) + Invalidate() +} + +// IntrospectionScopeFilterAccessPolicy is an access policy for filtering scopes. +type IntrospectionScopeFilterAccessPolicy struct { + ResourceNamespace string +} + +// IntrospectorOpts is a set of options for creating Introspector. +type IntrospectorOpts struct { + // GRPCClient is a GRPC client for doing introspection. + // If it is set, then introspection will be done using this client. + // Otherwise, introspection will be done via HTTP. + GRPCClient *GRPCClient + + // StaticHTTPEndpoint is a static URL for introspection. + // If it is set, then introspection will be done using this endpoint. + // Otherwise, introspection will be done using issuer URL (/.well-known/openid-configuration response). + // In this case, issuer URL should be present in JWT header or payload. + StaticHTTPEndpoint string + + // HTTPClient is an HTTP client for doing requests to /.well-known/openid-configuration and introspection endpoints. + HTTPClient *http.Client + + // AccessTokenScope is a scope for getting access token for doing introspection. + // The token should have introspection permission. + AccessTokenScope []string + + // ScopeFilter is a list of access policies for filtering scopes during introspection. + // If it is set, then only scopes that match at least one of the policies will be returned. + ScopeFilter []IntrospectionScopeFilterAccessPolicy + + // Logger is a logger for logging errors and debug information. + Logger log.FieldLogger + + // TrustedIssuerNotFoundFallback is a function called + // when given issuer from JWT is not found in the list of trusted ones. + TrustedIssuerNotFoundFallback TrustedIssNotFoundFallback + + // MinJWTVersion is a minimum JWT version for introspection. + // If JWT version is less than this value, then introspection will not be done + // and ErrTokenIntrospectionNotNeeded will be returned. + // Version is a value of "ver" field in JWT header. + // By default, it is 0. + // NOTE: it's a temporary solution for determining whether introspection is needed or not, + // and it will be removed in the future. + MinJWTVersion int + + // PrometheusLibInstanceLabel is a label for Prometheus metrics. + // It allows distinguishing metrics from different instances of the same library. + PrometheusLibInstanceLabel string +} + +// Introspector is a struct for introspecting tokens. +type Introspector struct { + accessTokenProvider IntrospectionTokenProvider + accessTokenProviderInvalidatedAt atomic.Value + accessTokenScope []string + + minJWTVersion int + jwtParser *jwtgo.Parser + + grpcClient *GRPCClient + staticHTTPURL string + httpClient *http.Client + + scopeFilter []IntrospectionScopeFilterAccessPolicy + scopeFilterFormURLEncoded string + + logger log.FieldLogger + + trustedIssuerStore *idputil.TrustedIssuerStore + trustedIssuerNotFoundFallback TrustedIssNotFoundFallback + + promMetrics *metrics.PrometheusMetrics +} + +// IntrospectionResult is a struct for introspection result. +type IntrospectionResult struct { + Active bool `json:"active"` + TokenType string `json:"token_type,omitempty"` + jwt.Claims +} + +// NewIntrospector creates a new Introspector with the given token provider. +func NewIntrospector(tokenProvider IntrospectionTokenProvider) *Introspector { + return NewIntrospectorWithOpts(tokenProvider, IntrospectorOpts{}) +} + +// NewIntrospectorWithOpts creates a new Introspector with the given token provider and options. +// See IntrospectorOpts for more details. +func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opts IntrospectorOpts) *Introspector { + if opts.HTTPClient == nil { + opts.HTTPClient = &http.Client{Timeout: DefaultRequestTimeout} + } + if opts.Logger == nil { + opts.Logger = log.NewDisabledLogger() + } + + values := url.Values{} + for i, policy := range opts.ScopeFilter { + values.Set("scope_filter["+strconv.Itoa(i)+"].rn", policy.ResourceNamespace) + } + scopeFilterFormURLEncoded := values.Encode() + + promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, tokenIntrospectorPromSource) + + return &Introspector{ + accessTokenProvider: accessTokenProvider, + accessTokenScope: opts.AccessTokenScope, + jwtParser: jwtgo.NewParser(), + logger: opts.Logger, + grpcClient: opts.GRPCClient, + httpClient: opts.HTTPClient, + scopeFilterFormURLEncoded: scopeFilterFormURLEncoded, + scopeFilter: opts.ScopeFilter, + staticHTTPURL: opts.StaticHTTPEndpoint, + minJWTVersion: opts.MinJWTVersion, + trustedIssuerStore: idputil.NewTrustedIssuerStore(), + trustedIssuerNotFoundFallback: opts.TrustedIssuerNotFoundFallback, + promMetrics: promMetrics, + } +} + +// IntrospectToken introspects the given token. +func (i *Introspector) IntrospectToken(ctx context.Context, token string) (IntrospectionResult, error) { + introspectFn, err := i.makeIntrospectFuncForToken(ctx, token) + if err != nil { + return IntrospectionResult{}, err + } + + result, err := introspectFn(ctx, token) + if err == nil { + return result, nil + } + + if !errors.Is(err, ErrTokenIntrospectionUnauthenticated) { + return IntrospectionResult{}, err + } + + // If introspection is unauthorized, then invalidate access token (if it is not invalidated recently) and try again. + t, ok := i.accessTokenProviderInvalidatedAt.Load().(time.Time) + now := time.Now() + if !ok || now.Sub(t) > minAccessTokenProviderInvalidationInterval { + i.accessTokenProvider.Invalidate() + i.accessTokenProviderInvalidatedAt.Store(now) + return introspectFn(ctx, token) + } + return IntrospectionResult{}, err +} + +// AddTrustedIssuer adds trusted issuer with specified name and URL. +func (i *Introspector) AddTrustedIssuer(issName, issURL string) { + i.trustedIssuerStore.AddTrustedIssuer(issName, issURL) +} + +// AddTrustedIssuerURL adds trusted issuer URL. +func (i *Introspector) AddTrustedIssuerURL(issURL string) error { + return i.trustedIssuerStore.AddTrustedIssuerURL(issURL) +} + +type introspectFunc func(ctx context.Context, token string) (IntrospectionResult, error) + +func (i *Introspector) makeIntrospectFuncForToken(ctx context.Context, token string) (introspectFunc, error) { + var err error + + if token == "" { + return i.makeStaticIntrospectFuncOrError(fmt.Errorf("token is missing")) + } + + jwtHeaderEndIdx := strings.IndexByte(token, '.') + if jwtHeaderEndIdx == -1 { + return i.makeStaticIntrospectFuncOrError(fmt.Errorf("no JWT header found")) + } + var jwtHeaderBytes []byte + if jwtHeaderBytes, err = i.jwtParser.DecodeSegment(token[:jwtHeaderEndIdx]); err != nil { + return i.makeStaticIntrospectFuncOrError(fmt.Errorf("decode JWT header: %w", err)) + } + jwtHeader := make(map[string]interface{}) + if err = json.Unmarshal(jwtHeaderBytes, &jwtHeader); err != nil { + return i.makeStaticIntrospectFuncOrError(fmt.Errorf("unmarshal JWT header: %w", err)) + } + if typ, ok := jwtHeader["typ"].(string); !ok || !strings.EqualFold(typ, JWTTypeAccessToken) { + return i.makeStaticIntrospectFuncOrError(fmt.Errorf("token type is not %s", JWTTypeAccessToken)) + } + if ver, ok := jwtHeader["ver"].(float64); ok && int(ver) < i.minJWTVersion { + return nil, ErrTokenIntrospectionNotNeeded + } + + if i.staticHTTPURL != "" { + return i.makeIntrospectFuncHTTP(i.staticHTTPURL), nil + } + if i.grpcClient != nil { + return i.makeIntrospectFuncGRPC(), nil + } + + // Try to get issuer from JWT header first and then from JWT payload. + // Issuer is usually presented in the JWT payload (it's an optional field, according to RFC 7519), + // but it could be in the header as well for optimization purposes. + // It's relevant for JWTs with large payloads. + issuer, ok := jwtHeader["iss"].(string) + if !ok || issuer == "" { + jwtPayloadEndIdx := strings.IndexByte(token[jwtHeaderEndIdx+1:], '.') + if jwtPayloadEndIdx == -1 { + return nil, makeTokenNotIntrospectableError(fmt.Errorf("no JWT payload found")) + } + var jwtPayloadBytes []byte + if jwtPayloadBytes, err = i.jwtParser.DecodeSegment( + token[jwtHeaderEndIdx+1 : jwtHeaderEndIdx+1+jwtPayloadEndIdx], + ); err != nil { + return nil, makeTokenNotIntrospectableError(fmt.Errorf("decode JWT payload: %w", err)) + } + var originalClaims jwt.Claims + if err = json.Unmarshal(jwtPayloadBytes, &originalClaims); err != nil { + return nil, makeTokenNotIntrospectableError(fmt.Errorf("unmarshal JWT payload: %w", err)) + } + if originalClaims.Issuer == "" { + return nil, makeTokenNotIntrospectableError(fmt.Errorf("no issuer found in JWT")) + } + issuer = originalClaims.Issuer + } + + issuerURL, ok := i.getURLForIssuerWithCallback(ctx, issuer) + if !ok { + return nil, makeTokenNotIntrospectableError(fmt.Errorf("issuer %q is not trusted", issuer)) + } + + // Try to get introspection endpoint URL from issuer. + introspectionEndpointURL, err := i.getWellKnownIntrospectionEndpointURL(ctx, issuerURL) + if err != nil { + return nil, fmt.Errorf("get introspection endpoint URL: %w", err) + } + return i.makeIntrospectFuncHTTP(introspectionEndpointURL), nil +} + +func (i *Introspector) makeStaticIntrospectFuncOrError(inner error) (introspectFunc, error) { + if i.grpcClient != nil { + return i.makeIntrospectFuncGRPC(), nil + } + if i.staticHTTPURL != "" { + return i.makeIntrospectFuncHTTP(i.staticHTTPURL), nil + } + return nil, makeTokenNotIntrospectableError(inner) +} + +func (i *Introspector) makeIntrospectFuncHTTP(introspectionEndpointURL string) introspectFunc { + return func(ctx context.Context, token string) (IntrospectionResult, error) { + accessToken, err := i.accessTokenProvider.GetToken(ctx, i.accessTokenScope...) + if err != nil { + return IntrospectionResult{}, fmt.Errorf("get access token for doing introspection: %w", err) + } + formEncoded := url.Values{"token": {token}}.Encode() + if i.scopeFilterFormURLEncoded != "" { + formEncoded += "&" + i.scopeFilterFormURLEncoded + } + req, err := http.NewRequest(http.MethodPost, introspectionEndpointURL, strings.NewReader(formEncoded)) + if err != nil { + return IntrospectionResult{}, fmt.Errorf("new request: %w", err) + } + req.Header.Set("Authorization", makeBearerToken(accessToken)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + startTime := time.Now() + resp, err := i.httpClient.Do(req.WithContext(ctx)) + elapsed := time.Since(startTime) + if err != nil { + i.promMetrics.ObserveHTTPClientRequest(http.MethodPost, introspectionEndpointURL, 0, elapsed, metrics.HTTPRequestErrorDo) + return IntrospectionResult{}, fmt.Errorf("do request: %w", err) + } + defer func() { + if closeBodyErr := resp.Body.Close(); closeBodyErr != nil { + i.logger.Error(fmt.Sprintf("closing response body error for POST %s", introspectionEndpointURL), + log.Error(closeBodyErr)) + } + }() + if resp.StatusCode != http.StatusOK { + i.promMetrics.ObserveHTTPClientRequest( + http.MethodPost, introspectionEndpointURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorUnexpectedStatusCode) + if resp.StatusCode == http.StatusUnauthorized { + return IntrospectionResult{}, ErrTokenIntrospectionUnauthenticated + } + return IntrospectionResult{}, fmt.Errorf("unexpected HTTP code %d for POST %s", resp.StatusCode, introspectionEndpointURL) + } + + var res IntrospectionResult + if err = json.NewDecoder(resp.Body).Decode(&res); err != nil { + i.promMetrics.ObserveHTTPClientRequest( + http.MethodPost, introspectionEndpointURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorDecodeBody) + return IntrospectionResult{}, fmt.Errorf("decode response body json for POST %s: %w", introspectionEndpointURL, err) + } + + i.promMetrics.ObserveHTTPClientRequest(http.MethodPost, introspectionEndpointURL, resp.StatusCode, elapsed, "") + return res, nil + } +} + +func (i *Introspector) makeIntrospectFuncGRPC() introspectFunc { + return func(ctx context.Context, token string) (IntrospectionResult, error) { + accessToken, err := i.accessTokenProvider.GetToken(ctx, i.accessTokenScope...) + if err != nil { + return IntrospectionResult{}, fmt.Errorf("get access token for doing introspection: %w", err) + } + res, err := i.grpcClient.IntrospectToken(ctx, token, i.scopeFilter, accessToken) + if err != nil { + return IntrospectionResult{}, fmt.Errorf("introspect token: %w", err) + } + return res, nil + } +} + +func (i *Introspector) getWellKnownIntrospectionEndpointURL(ctx context.Context, issuerURL string) (string, error) { + openIDCfgURL := strings.TrimSuffix(issuerURL, "/") + wellKnownPath + openIDCfg, err := idputil.GetOpenIDConfiguration( + ctx, i.httpClient, openIDCfgURL, nil, i.logger, i.promMetrics) + if err != nil { + return "", fmt.Errorf("get OpenID configuration: %w", err) + } + if openIDCfg.IntrospectionEndpoint == "" { + return "", fmt.Errorf("no introspection endpoint URL found on %s", openIDCfgURL) + } + return openIDCfg.IntrospectionEndpoint, nil +} + +func (i *Introspector) getURLForIssuerWithCallback(ctx context.Context, issuer string) (string, bool) { + issURL, issFound := i.trustedIssuerStore.GetURLForIssuer(issuer) + if issFound { + return issURL, true + } + if i.trustedIssuerNotFoundFallback == nil { + return "", false + } + return i.trustedIssuerNotFoundFallback(ctx, i, issuer) +} + +func makeTokenNotIntrospectableError(inner error) error { + if inner != nil { + return fmt.Errorf("%w: %w", ErrTokenNotIntrospectable, inner) + } + return ErrTokenNotIntrospectable +} + +func makeBearerToken(token string) string { + return "Bearer " + token +} diff --git a/idptoken/introspector_test.go b/idptoken/introspector_test.go new file mode 100644 index 0000000..5757b0c --- /dev/null +++ b/idptoken/introspector_test.go @@ -0,0 +1,306 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken_test + +import ( + "context" + "net/http" + "net/url" + gotesting "testing" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/credentials/insecure" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/idptoken/pb" + "github.com/acronis/go-authkit/internal/testing" + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +func TestIntrospector_IntrospectToken(t *gotesting.T) { + httpServerIntrospector := testing.NewHTTPServerTokenIntrospectorMock() + grpcServerIntrospector := testing.NewGRPCServerTokenIntrospectorMock() + + httpIDPSrv := idptest.NewHTTPServer(idptest.WithHTTPTokenIntrospector(httpServerIntrospector)) + require.NoError(t, httpIDPSrv.StartAndWaitForReady(time.Second)) + defer func() { _ = httpIDPSrv.Shutdown(context.Background()) }() + + grpcIDPSrv := idptest.NewGRPCServer(idptest.WithGRPCTokenIntrospector(grpcServerIntrospector)) + require.NoError(t, grpcIDPSrv.StartAndWaitForReady(time.Second)) + defer func() { grpcIDPSrv.GracefulStop() }() + + const accessToken = "access-token-with-introspection-permission" + tokenProvider := idptest.NewSimpleTokenProvider(accessToken) + + logger := log.NewDisabledLogger() + jwtParser := jwt.NewParser(jwks.NewClient(http.DefaultClient, logger), logger) + require.NoError(t, jwtParser.AddTrustedIssuerURL(httpIDPSrv.URL())) + httpServerIntrospector.JWTParser = jwtParser + grpcServerIntrospector.JWTParser = jwtParser + + jwtScopeToGRPC := func(jwtScope []jwt.AccessPolicy) []*pb.AccessTokenScope { + grpcScope := make([]*pb.AccessTokenScope, len(jwtScope)) + for i, scope := range jwtScope { + grpcScope[i] = &pb.AccessTokenScope{ + TenantUuid: scope.TenantUUID, + ResourceNamespace: scope.ResourceNamespace, + RoleName: scope.Role, + ResourcePath: scope.ResourcePath, + } + } + return grpcScope + } + + jwtExpiresAtInFuture := jwtgo.NewNumericDate(time.Now().Add(time.Hour)) + jwtIssuer := httpIDPSrv.URL() + jwtSubject := uuid.NewString() + jwtID := uuid.NewString() + jwtScope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "account_viewer", + ResourcePath: "resource-" + uuid.NewString(), + }} + + opaqueToken := "opaque-token-" + uuid.NewString() + opaqueTokenScope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "admin", + ResourcePath: "resource-" + uuid.NewString(), + }} + + httpServerIntrospector.SetScopeForJWTID(jwtID, jwtScope) + httpServerIntrospector.SetResultForToken(opaqueToken, idptoken.IntrospectionResult{ + Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueTokenScope}}) + grpcServerIntrospector.SetScopeForJWTID(jwtID, jwtScopeToGRPC(jwtScope)) + grpcServerIntrospector.SetResultForToken(opaqueToken, &pb.IntrospectTokenResponse{ + Active: true, TokenType: idptoken.TokenTypeBearer, Scope: jwtScopeToGRPC(opaqueTokenScope)}) + + tests := []struct { + name string + introspectorOpts idptoken.IntrospectorOpts + useGRPC bool + token string + expectedResult idptoken.IntrospectionResult + checkError func(t *gotesting.T, err error) + expectedHTTPSrvCalled bool + expectedHTTPFormVals url.Values + expectedGRPCSrvCalled bool + expectedGRPCScopeFilter []*pb.IntrospectionScopeFilter + }{ + { + name: "error, token is missing", + token: "", + checkError: func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, "token is missing") + }, + }, + { + name: "error, dynamic introspection endpoint, no jwt header", + token: "opaque-token", + checkError: func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, "no JWT header found") + }, + }, + { + name: "error, dynamic introspection endpoint, cannot decode jwt header", + token: "$opaque$.$token$", + checkError: func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, "decode JWT header") + }, + }, + { + name: "error, dynamic introspection endpoint, issuer is not trusted", + token: idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: "https://untrusted-issuer.com", + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + }, + }), + checkError: func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, `issuer "https://untrusted-issuer.com" is not trusted`) + }, + }, + { + name: "error, dynamic introspection endpoint, issuer is missing in JWT header and payload", + token: idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + }, + }), + checkError: func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, "no issuer found in JWT") + }, + }, + { + name: "error, dynamic introspection endpoint, introspection is not needed", + introspectorOpts: idptoken.IntrospectorOpts{ + MinJWTVersion: 3, + }, + token: idptest.MustMakeTokenStringWithHeader(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + }, + }, idptest.TestKeyID, idptest.GetTestRSAPrivateKey(), map[string]interface{}{"ver": 2}), + checkError: func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenIntrospectionNotNeeded) + }, + }, + { + name: "ok, dynamic introspection endpoint, introspected token is expired JWT", + token: idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: httpIDPSrv.URL(), + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(-time.Hour)), + }, + }), + expectedResult: idptoken.IntrospectionResult{Active: false}, + expectedHTTPSrvCalled: true, + }, + { + name: "ok, dynamic introspection endpoint, introspected token is JWT", + token: idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: jwtIssuer, + Subject: jwtSubject, + ID: jwtID, + ExpiresAt: jwtExpiresAtInFuture, + }, + }), + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: idptoken.TokenTypeBearer, + Claims: jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: jwtIssuer, + Subject: jwtSubject, + ID: jwtID, + ExpiresAt: jwtExpiresAtInFuture, + }, + Scope: jwtScope, + }, + }, + expectedHTTPSrvCalled: true, + }, + { + name: "ok, static introspection endpoint, introspected token is opaque", + introspectorOpts: idptoken.IntrospectorOpts{ + StaticHTTPEndpoint: httpIDPSrv.URL() + idptest.TokenIntrospectionEndpointPath, + }, + token: opaqueToken, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: idptoken.TokenTypeBearer, + Claims: jwt.Claims{Scope: opaqueTokenScope}, + }, + expectedHTTPSrvCalled: true, + }, + { + name: "ok, static introspection endpoint, introspected token is opaque, filter scope by resource namespace", + introspectorOpts: idptoken.IntrospectorOpts{ + StaticHTTPEndpoint: httpIDPSrv.URL() + idptest.TokenIntrospectionEndpointPath, + ScopeFilter: []idptoken.IntrospectionScopeFilterAccessPolicy{ + {ResourceNamespace: "account-server"}, + {ResourceNamespace: "tenant-manager"}, + }, + }, + token: opaqueToken, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: idptoken.TokenTypeBearer, + Claims: jwt.Claims{Scope: opaqueTokenScope}, + }, + expectedHTTPSrvCalled: true, + expectedHTTPFormVals: url.Values{ + "token": {opaqueToken}, + "scope_filter[0].rn": {"account-server"}, + "scope_filter[1].rn": {"tenant-manager"}, + }, + }, + { + name: "ok, grpc introspection endpoint", + useGRPC: true, + introspectorOpts: idptoken.IntrospectorOpts{ + ScopeFilter: []idptoken.IntrospectionScopeFilterAccessPolicy{ + {ResourceNamespace: "account-server"}, + {ResourceNamespace: "tenant-manager"}, + }, + }, + token: opaqueToken, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: idptoken.TokenTypeBearer, + Claims: jwt.Claims{Scope: opaqueTokenScope}, + }, + expectedGRPCSrvCalled: true, + expectedGRPCScopeFilter: []*pb.IntrospectionScopeFilter{ + {ResourceNamespace: "account-server"}, + {ResourceNamespace: "tenant-manager"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *gotesting.T) { + if tt.useGRPC { + grpcClient, err := idptoken.NewGRPCClient(grpcIDPSrv.Addr(), insecure.NewCredentials()) + require.NoError(t, err) + defer func() { require.NoError(t, grpcClient.Close()) }() + tt.introspectorOpts.GRPCClient = grpcClient + } + introspector := idptoken.NewIntrospectorWithOpts(tokenProvider, tt.introspectorOpts) + require.NoError(t, introspector.AddTrustedIssuerURL(httpIDPSrv.URL())) + + httpServerIntrospector.ResetCallsInfo() + grpcServerIntrospector.ResetCallsInfo() + + result, err := introspector.IntrospectToken(context.Background(), tt.token) + if tt.checkError != nil { + tt.checkError(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedResult, result) + } + + require.Equal(t, tt.expectedHTTPSrvCalled, httpServerIntrospector.Called) + if tt.expectedHTTPSrvCalled { + require.Equal(t, tt.token, httpServerIntrospector.LastIntrospectedToken) + require.Equal(t, "Bearer "+accessToken, httpServerIntrospector.LastAuthorizationHeader) + if tt.expectedHTTPFormVals == nil { + tt.expectedHTTPFormVals = url.Values{"token": {tt.token}} + } + require.Equal(t, tt.expectedHTTPFormVals, httpServerIntrospector.LastFormValues) + } + + require.Equal(t, tt.expectedGRPCSrvCalled, grpcServerIntrospector.Called) + if tt.expectedGRPCSrvCalled { + require.Equal(t, tt.token, grpcServerIntrospector.LastRequest.Token) + require.Equal(t, tt.expectedGRPCScopeFilter, grpcServerIntrospector.LastRequest.GetScopeFilter()) + require.Equal(t, "Bearer "+accessToken, grpcServerIntrospector.LastAuthorizationMeta) + } + }) + } +} diff --git a/idptoken/pb/idp_token.pb.go b/idptoken/pb/idp_token.pb.go new file mode 100644 index 0000000..0feb757 --- /dev/null +++ b/idptoken/pb/idp_token.pb.go @@ -0,0 +1,682 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.33.0 +// protoc v5.26.1 +// source: idp_token.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type CreateTokenRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + GrantType string `protobuf:"bytes,1,opt,name=grant_type,json=grantType,proto3" json:"grant_type,omitempty"` // example: urn:ietf:params:oauth:grant-type:jwt-bearer + Assertion string `protobuf:"bytes,2,opt,name=assertion,proto3" json:"assertion,omitempty"` + TokenVersion uint32 `protobuf:"varint,3,opt,name=token_version,json=tokenVersion,proto3" json:"token_version,omitempty"` +} + +func (x *CreateTokenRequest) Reset() { + *x = CreateTokenRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_idp_token_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CreateTokenRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CreateTokenRequest) ProtoMessage() {} + +func (x *CreateTokenRequest) ProtoReflect() protoreflect.Message { + mi := &file_idp_token_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CreateTokenRequest.ProtoReflect.Descriptor instead. +func (*CreateTokenRequest) Descriptor() ([]byte, []int) { + return file_idp_token_proto_rawDescGZIP(), []int{0} +} + +func (x *CreateTokenRequest) GetGrantType() string { + if x != nil { + return x.GrantType + } + return "" +} + +func (x *CreateTokenRequest) GetAssertion() string { + if x != nil { + return x.Assertion + } + return "" +} + +func (x *CreateTokenRequest) GetTokenVersion() uint32 { + if x != nil { + return x.TokenVersion + } + return 0 +} + +type CreateTokenResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + AccessToken string `protobuf:"bytes,1,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"` + TokenType string `protobuf:"bytes,2,opt,name=token_type,json=tokenType,proto3" json:"token_type,omitempty"` + ExpiresIn int64 `protobuf:"varint,3,opt,name=expires_in,json=expiresIn,proto3" json:"expires_in,omitempty"` +} + +func (x *CreateTokenResponse) Reset() { + *x = CreateTokenResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_idp_token_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CreateTokenResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CreateTokenResponse) ProtoMessage() {} + +func (x *CreateTokenResponse) ProtoReflect() protoreflect.Message { + mi := &file_idp_token_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CreateTokenResponse.ProtoReflect.Descriptor instead. +func (*CreateTokenResponse) Descriptor() ([]byte, []int) { + return file_idp_token_proto_rawDescGZIP(), []int{1} +} + +func (x *CreateTokenResponse) GetAccessToken() string { + if x != nil { + return x.AccessToken + } + return "" +} + +func (x *CreateTokenResponse) GetTokenType() string { + if x != nil { + return x.TokenType + } + return "" +} + +func (x *CreateTokenResponse) GetExpiresIn() int64 { + if x != nil { + return x.ExpiresIn + } + return 0 +} + +type IntrospectionScopeFilter struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ResourceNamespace string `protobuf:"bytes,1,opt,name=ResourceNamespace,proto3" json:"ResourceNamespace,omitempty"` +} + +func (x *IntrospectionScopeFilter) Reset() { + *x = IntrospectionScopeFilter{} + if protoimpl.UnsafeEnabled { + mi := &file_idp_token_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *IntrospectionScopeFilter) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IntrospectionScopeFilter) ProtoMessage() {} + +func (x *IntrospectionScopeFilter) ProtoReflect() protoreflect.Message { + mi := &file_idp_token_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IntrospectionScopeFilter.ProtoReflect.Descriptor instead. +func (*IntrospectionScopeFilter) Descriptor() ([]byte, []int) { + return file_idp_token_proto_rawDescGZIP(), []int{2} +} + +func (x *IntrospectionScopeFilter) GetResourceNamespace() string { + if x != nil { + return x.ResourceNamespace + } + return "" +} + +type IntrospectTokenRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` + ScopeFilter []*IntrospectionScopeFilter `protobuf:"bytes,2,rep,name=scope_filter,json=scopeFilter,proto3" json:"scope_filter,omitempty"` +} + +func (x *IntrospectTokenRequest) Reset() { + *x = IntrospectTokenRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_idp_token_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *IntrospectTokenRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IntrospectTokenRequest) ProtoMessage() {} + +func (x *IntrospectTokenRequest) ProtoReflect() protoreflect.Message { + mi := &file_idp_token_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IntrospectTokenRequest.ProtoReflect.Descriptor instead. +func (*IntrospectTokenRequest) Descriptor() ([]byte, []int) { + return file_idp_token_proto_rawDescGZIP(), []int{3} +} + +func (x *IntrospectTokenRequest) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + +func (x *IntrospectTokenRequest) GetScopeFilter() []*IntrospectionScopeFilter { + if x != nil { + return x.ScopeFilter + } + return nil +} + +type AccessTokenScope struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + TenantUuid string `protobuf:"bytes,1,opt,name=tenant_uuid,json=tenantUuid,proto3" json:"tenant_uuid,omitempty"` + TenantIntId int64 `protobuf:"varint,2,opt,name=tenant_int_id,json=tenantIntId,proto3" json:"tenant_int_id,omitempty"` + ResourceServer string `protobuf:"bytes,3,opt,name=ResourceServer,proto3" json:"ResourceServer,omitempty"` + ResourceNamespace string `protobuf:"bytes,4,opt,name=ResourceNamespace,proto3" json:"ResourceNamespace,omitempty"` + ResourcePath string `protobuf:"bytes,5,opt,name=ResourcePath,proto3" json:"ResourcePath,omitempty"` + RoleName string `protobuf:"bytes,6,opt,name=RoleName,proto3" json:"RoleName,omitempty"` +} + +func (x *AccessTokenScope) Reset() { + *x = AccessTokenScope{} + if protoimpl.UnsafeEnabled { + mi := &file_idp_token_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AccessTokenScope) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AccessTokenScope) ProtoMessage() {} + +func (x *AccessTokenScope) ProtoReflect() protoreflect.Message { + mi := &file_idp_token_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AccessTokenScope.ProtoReflect.Descriptor instead. +func (*AccessTokenScope) Descriptor() ([]byte, []int) { + return file_idp_token_proto_rawDescGZIP(), []int{4} +} + +func (x *AccessTokenScope) GetTenantUuid() string { + if x != nil { + return x.TenantUuid + } + return "" +} + +func (x *AccessTokenScope) GetTenantIntId() int64 { + if x != nil { + return x.TenantIntId + } + return 0 +} + +func (x *AccessTokenScope) GetResourceServer() string { + if x != nil { + return x.ResourceServer + } + return "" +} + +func (x *AccessTokenScope) GetResourceNamespace() string { + if x != nil { + return x.ResourceNamespace + } + return "" +} + +func (x *AccessTokenScope) GetResourcePath() string { + if x != nil { + return x.ResourcePath + } + return "" +} + +func (x *AccessTokenScope) GetRoleName() string { + if x != nil { + return x.RoleName + } + return "" +} + +type IntrospectTokenResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Active bool `protobuf:"varint,1,opt,name=active,proto3" json:"active,omitempty"` + TokenType string `protobuf:"bytes,2,opt,name=token_type,json=tokenType,proto3" json:"token_type,omitempty"` + Exp int64 `protobuf:"varint,3,opt,name=exp,proto3" json:"exp,omitempty"` + Aud []string `protobuf:"bytes,4,rep,name=aud,proto3" json:"aud,omitempty"` + Jti string `protobuf:"bytes,5,opt,name=jti,proto3" json:"jti,omitempty"` + Iss string `protobuf:"bytes,6,opt,name=iss,proto3" json:"iss,omitempty"` + Sub string `protobuf:"bytes,7,opt,name=sub,proto3" json:"sub,omitempty"` + SubType string `protobuf:"bytes,8,opt,name=sub_type,json=subType,proto3" json:"sub_type,omitempty"` + ClientId string `protobuf:"bytes,9,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` + OwnerTenantUuid string `protobuf:"bytes,10,opt,name=owner_tenant_uuid,json=ownerTenantUuid,proto3" json:"owner_tenant_uuid,omitempty"` // API client owner tenant UUID. + Scope []*AccessTokenScope `protobuf:"bytes,11,rep,name=scope,proto3" json:"scope,omitempty"` +} + +func (x *IntrospectTokenResponse) Reset() { + *x = IntrospectTokenResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_idp_token_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *IntrospectTokenResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IntrospectTokenResponse) ProtoMessage() {} + +func (x *IntrospectTokenResponse) ProtoReflect() protoreflect.Message { + mi := &file_idp_token_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IntrospectTokenResponse.ProtoReflect.Descriptor instead. +func (*IntrospectTokenResponse) Descriptor() ([]byte, []int) { + return file_idp_token_proto_rawDescGZIP(), []int{5} +} + +func (x *IntrospectTokenResponse) GetActive() bool { + if x != nil { + return x.Active + } + return false +} + +func (x *IntrospectTokenResponse) GetTokenType() string { + if x != nil { + return x.TokenType + } + return "" +} + +func (x *IntrospectTokenResponse) GetExp() int64 { + if x != nil { + return x.Exp + } + return 0 +} + +func (x *IntrospectTokenResponse) GetAud() []string { + if x != nil { + return x.Aud + } + return nil +} + +func (x *IntrospectTokenResponse) GetJti() string { + if x != nil { + return x.Jti + } + return "" +} + +func (x *IntrospectTokenResponse) GetIss() string { + if x != nil { + return x.Iss + } + return "" +} + +func (x *IntrospectTokenResponse) GetSub() string { + if x != nil { + return x.Sub + } + return "" +} + +func (x *IntrospectTokenResponse) GetSubType() string { + if x != nil { + return x.SubType + } + return "" +} + +func (x *IntrospectTokenResponse) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *IntrospectTokenResponse) GetOwnerTenantUuid() string { + if x != nil { + return x.OwnerTenantUuid + } + return "" +} + +func (x *IntrospectTokenResponse) GetScope() []*AccessTokenScope { + if x != nil { + return x.Scope + } + return nil +} + +var File_idp_token_proto protoreflect.FileDescriptor + +var file_idp_token_proto_rawDesc = []byte{ + 0x0a, 0x0f, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x12, 0x09, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0x7c, 0x0a, 0x12, + 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x67, 0x72, 0x61, 0x6e, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x67, 0x72, 0x61, 0x6e, 0x74, 0x54, 0x79, 0x70, + 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x73, 0x73, 0x65, 0x72, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x73, 0x73, 0x65, 0x72, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x23, 0x0a, 0x0d, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x4a, 0x04, 0x08, 0x04, 0x10, 0x33, 0x22, 0x7c, 0x0a, 0x13, 0x43, 0x72, + 0x65, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, + 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x74, 0x79, + 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x54, + 0x79, 0x70, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x5f, 0x69, + 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, + 0x49, 0x6e, 0x4a, 0x04, 0x08, 0x04, 0x10, 0x33, 0x22, 0x4e, 0x0a, 0x18, 0x49, 0x6e, 0x74, 0x72, + 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x46, 0x69, + 0x6c, 0x74, 0x65, 0x72, 0x12, 0x2c, 0x0a, 0x11, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, + 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x11, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, + 0x63, 0x65, 0x4a, 0x04, 0x08, 0x02, 0x10, 0x33, 0x22, 0x7c, 0x0a, 0x16, 0x49, 0x6e, 0x74, 0x72, + 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x46, 0x0a, 0x0c, 0x73, 0x63, 0x6f, 0x70, + 0x65, 0x5f, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x23, + 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x49, 0x6e, 0x74, 0x72, 0x6f, + 0x73, 0x70, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x46, 0x69, 0x6c, + 0x74, 0x65, 0x72, 0x52, 0x0b, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, + 0x4a, 0x04, 0x08, 0x03, 0x10, 0x33, 0x22, 0xf3, 0x01, 0x0a, 0x10, 0x41, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x74, + 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x5f, 0x75, 0x75, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0a, 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x55, 0x75, 0x69, 0x64, 0x12, 0x22, 0x0a, 0x0d, + 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x5f, 0x69, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x0b, 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x49, 0x6e, 0x74, 0x49, 0x64, + 0x12, 0x26, 0x0a, 0x0e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, + 0x63, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x2c, 0x0a, 0x11, 0x52, 0x65, 0x73, 0x6f, + 0x75, 0x72, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x11, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x4e, 0x61, 0x6d, + 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, + 0x63, 0x65, 0x50, 0x61, 0x74, 0x68, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, + 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x61, 0x74, 0x68, 0x12, 0x1a, 0x0a, 0x08, 0x52, 0x6f, + 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x52, 0x6f, + 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x4a, 0x04, 0x08, 0x07, 0x10, 0x33, 0x22, 0xc7, 0x02, 0x0a, + 0x17, 0x49, 0x6e, 0x74, 0x72, 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, + 0x76, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, + 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x12, + 0x10, 0x0a, 0x03, 0x65, 0x78, 0x70, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x65, 0x78, + 0x70, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x75, 0x64, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, + 0x61, 0x75, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6a, 0x74, 0x69, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x03, 0x6a, 0x74, 0x69, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x73, 0x73, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x69, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x75, 0x62, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x75, 0x62, 0x12, 0x19, 0x0a, 0x08, 0x73, 0x75, 0x62, + 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x73, 0x75, 0x62, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, + 0x64, 0x12, 0x2a, 0x0a, 0x11, 0x6f, 0x77, 0x6e, 0x65, 0x72, 0x5f, 0x74, 0x65, 0x6e, 0x61, 0x6e, + 0x74, 0x5f, 0x75, 0x75, 0x69, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x6f, 0x77, + 0x6e, 0x65, 0x72, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x55, 0x75, 0x69, 0x64, 0x12, 0x31, 0x0a, + 0x05, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x69, + 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x52, 0x05, 0x73, 0x63, 0x6f, 0x70, 0x65, + 0x4a, 0x04, 0x08, 0x0c, 0x10, 0x65, 0x32, 0xb9, 0x01, 0x0a, 0x0f, 0x49, 0x44, 0x50, 0x54, 0x6f, + 0x6b, 0x65, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x4c, 0x0a, 0x0b, 0x43, 0x72, + 0x65, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x1d, 0x2e, 0x69, 0x64, 0x70, 0x5f, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, + 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x58, 0x0a, 0x0f, 0x49, 0x6e, 0x74, 0x72, + 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x21, 0x2e, 0x69, 0x64, + 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x49, 0x6e, 0x74, 0x72, 0x6f, 0x73, 0x70, 0x65, + 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, + 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x49, 0x6e, 0x74, 0x72, 0x6f, + 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x42, 0x06, 0x5a, 0x04, 0x2e, 0x2f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, +} + +var ( + file_idp_token_proto_rawDescOnce sync.Once + file_idp_token_proto_rawDescData = file_idp_token_proto_rawDesc +) + +func file_idp_token_proto_rawDescGZIP() []byte { + file_idp_token_proto_rawDescOnce.Do(func() { + file_idp_token_proto_rawDescData = protoimpl.X.CompressGZIP(file_idp_token_proto_rawDescData) + }) + return file_idp_token_proto_rawDescData +} + +var file_idp_token_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_idp_token_proto_goTypes = []interface{}{ + (*CreateTokenRequest)(nil), // 0: idp_token.CreateTokenRequest + (*CreateTokenResponse)(nil), // 1: idp_token.CreateTokenResponse + (*IntrospectionScopeFilter)(nil), // 2: idp_token.IntrospectionScopeFilter + (*IntrospectTokenRequest)(nil), // 3: idp_token.IntrospectTokenRequest + (*AccessTokenScope)(nil), // 4: idp_token.AccessTokenScope + (*IntrospectTokenResponse)(nil), // 5: idp_token.IntrospectTokenResponse +} +var file_idp_token_proto_depIdxs = []int32{ + 2, // 0: idp_token.IntrospectTokenRequest.scope_filter:type_name -> idp_token.IntrospectionScopeFilter + 4, // 1: idp_token.IntrospectTokenResponse.scope:type_name -> idp_token.AccessTokenScope + 0, // 2: idp_token.IDPTokenService.CreateToken:input_type -> idp_token.CreateTokenRequest + 3, // 3: idp_token.IDPTokenService.IntrospectToken:input_type -> idp_token.IntrospectTokenRequest + 1, // 4: idp_token.IDPTokenService.CreateToken:output_type -> idp_token.CreateTokenResponse + 5, // 5: idp_token.IDPTokenService.IntrospectToken:output_type -> idp_token.IntrospectTokenResponse + 4, // [4:6] is the sub-list for method output_type + 2, // [2:4] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_idp_token_proto_init() } +func file_idp_token_proto_init() { + if File_idp_token_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_idp_token_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateTokenRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_idp_token_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateTokenResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_idp_token_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*IntrospectionScopeFilter); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_idp_token_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*IntrospectTokenRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_idp_token_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*AccessTokenScope); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_idp_token_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*IntrospectTokenResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_idp_token_proto_rawDesc, + NumEnums: 0, + NumMessages: 6, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_idp_token_proto_goTypes, + DependencyIndexes: file_idp_token_proto_depIdxs, + MessageInfos: file_idp_token_proto_msgTypes, + }.Build() + File_idp_token_proto = out.File + file_idp_token_proto_rawDesc = nil + file_idp_token_proto_goTypes = nil + file_idp_token_proto_depIdxs = nil +} diff --git a/idptoken/pb/idp_token_grpc.pb.go b/idptoken/pb/idp_token_grpc.pb.go new file mode 100644 index 0000000..a338bb6 --- /dev/null +++ b/idptoken/pb/idp_token_grpc.pb.go @@ -0,0 +1,154 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.3.0 +// - protoc v5.26.1 +// source: idp_token.proto + +package pb + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +const ( + IDPTokenService_CreateToken_FullMethodName = "/idp_token.IDPTokenService/CreateToken" + IDPTokenService_IntrospectToken_FullMethodName = "/idp_token.IDPTokenService/IntrospectToken" +) + +// IDPTokenServiceClient is the client API for IDPTokenService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type IDPTokenServiceClient interface { + // CreateToken creates a new token based on the provided assertion. + // Now only JWT_BEARER (urn:ietf:params:oauth:grant-type:jwt-bearer) grant type is supported. + CreateToken(ctx context.Context, in *CreateTokenRequest, opts ...grpc.CallOption) (*CreateTokenResponse, error) + // IntrospectToken returns information about the token including its scopes. + // The token is considered active if it 1) is not expired; 2) is not revoked; 3) has has the valid signature. + IntrospectToken(ctx context.Context, in *IntrospectTokenRequest, opts ...grpc.CallOption) (*IntrospectTokenResponse, error) +} + +type iDPTokenServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewIDPTokenServiceClient(cc grpc.ClientConnInterface) IDPTokenServiceClient { + return &iDPTokenServiceClient{cc} +} + +func (c *iDPTokenServiceClient) CreateToken(ctx context.Context, in *CreateTokenRequest, opts ...grpc.CallOption) (*CreateTokenResponse, error) { + out := new(CreateTokenResponse) + err := c.cc.Invoke(ctx, IDPTokenService_CreateToken_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *iDPTokenServiceClient) IntrospectToken(ctx context.Context, in *IntrospectTokenRequest, opts ...grpc.CallOption) (*IntrospectTokenResponse, error) { + out := new(IntrospectTokenResponse) + err := c.cc.Invoke(ctx, IDPTokenService_IntrospectToken_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// IDPTokenServiceServer is the server API for IDPTokenService service. +// All implementations must embed UnimplementedIDPTokenServiceServer +// for forward compatibility +type IDPTokenServiceServer interface { + // CreateToken creates a new token based on the provided assertion. + // Now only JWT_BEARER (urn:ietf:params:oauth:grant-type:jwt-bearer) grant type is supported. + CreateToken(context.Context, *CreateTokenRequest) (*CreateTokenResponse, error) + // IntrospectToken returns information about the token including its scopes. + // The token is considered active if it 1) is not expired; 2) is not revoked; 3) has has the valid signature. + IntrospectToken(context.Context, *IntrospectTokenRequest) (*IntrospectTokenResponse, error) + mustEmbedUnimplementedIDPTokenServiceServer() +} + +// UnimplementedIDPTokenServiceServer must be embedded to have forward compatible implementations. +type UnimplementedIDPTokenServiceServer struct { +} + +func (UnimplementedIDPTokenServiceServer) CreateToken(context.Context, *CreateTokenRequest) (*CreateTokenResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method CreateToken not implemented") +} +func (UnimplementedIDPTokenServiceServer) IntrospectToken(context.Context, *IntrospectTokenRequest) (*IntrospectTokenResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method IntrospectToken not implemented") +} +func (UnimplementedIDPTokenServiceServer) mustEmbedUnimplementedIDPTokenServiceServer() {} + +// UnsafeIDPTokenServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to IDPTokenServiceServer will +// result in compilation errors. +type UnsafeIDPTokenServiceServer interface { + mustEmbedUnimplementedIDPTokenServiceServer() +} + +func RegisterIDPTokenServiceServer(s grpc.ServiceRegistrar, srv IDPTokenServiceServer) { + s.RegisterService(&IDPTokenService_ServiceDesc, srv) +} + +func _IDPTokenService_CreateToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CreateTokenRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(IDPTokenServiceServer).CreateToken(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: IDPTokenService_CreateToken_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(IDPTokenServiceServer).CreateToken(ctx, req.(*CreateTokenRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _IDPTokenService_IntrospectToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(IntrospectTokenRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(IDPTokenServiceServer).IntrospectToken(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: IDPTokenService_IntrospectToken_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(IDPTokenServiceServer).IntrospectToken(ctx, req.(*IntrospectTokenRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// IDPTokenService_ServiceDesc is the grpc.ServiceDesc for IDPTokenService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var IDPTokenService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "idp_token.IDPTokenService", + HandlerType: (*IDPTokenServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "CreateToken", + Handler: _IDPTokenService_CreateToken_Handler, + }, + { + MethodName: "IntrospectToken", + Handler: _IDPTokenService_IntrospectToken_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "idp_token.proto", +} diff --git a/idptoken/provider.go b/idptoken/provider.go new file mode 100644 index 0000000..80f44a5 --- /dev/null +++ b/idptoken/provider.go @@ -0,0 +1,692 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken + +import ( + "context" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "math/big" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/acronis/go-appkit/log" + "golang.org/x/sync/singleflight" + + "github.com/acronis/go-authkit/internal/idputil" + "github.com/acronis/go-authkit/internal/metrics" +) + +const ( + defaultMinRefreshPeriod = time.Second * 10 + defaultExpirationOffset = time.Minute * 30 + expiryDeltaMaxOffset = 5 + wellKnownPath = "/.well-known/openid-configuration" +) + +var ( + // ErrSourceNotRegistered is returned if GetToken is requested for the unknown Source + ErrSourceNotRegistered = errors.New("cannot issue token for unknown source") +) + +// UnexpectedIDPResponseError is an error representing an unexpected response +type UnexpectedIDPResponseError struct { + HTTPCode int + IssueURL string +} + +func (e *UnexpectedIDPResponseError) Error() string { + return fmt.Sprintf(`%s responded with unexpected code %d`, e.IssueURL, e.HTTPCode) +} + +// TokenData represents API-related token information +type TokenData struct { + Data string + ClientID string + issueURL string + Scope []string + Expires time.Time +} + +// Source serves to provide auth source information to MultiSourceProvider and Provider +type Source struct { + URL string + ClientID string + ClientSecret string +} + +var zeroTime = time.Time{} + +// TokenDetails represents the data to be stored in TokenCache +type TokenDetails struct { + token TokenData + requestedScope []string + sourceURL string + issued time.Time + nextRefresh time.Time + invalidation time.Time +} + +// TokenCache is a cache entry used to store TokenDetails based on a string key +type TokenCache interface { + // Get returns a value from the cache by key. + Get(key string) *TokenDetails + + // Put sets a new value to the cache by key. + Put(key string, val *TokenDetails) + + // Delete removes a value from the cache by key. + Delete(key string) + + // ClearAll removes all values from the cache. + ClearAll() + + // Keys returns all keys from the cache. + Keys() []string +} + +type InMemoryTokenCache struct { + mu sync.RWMutex + items map[string]*TokenDetails +} + +func NewInMemoryTokenCache() *InMemoryTokenCache { + return &InMemoryTokenCache{items: make(map[string]*TokenDetails)} +} + +func (c *InMemoryTokenCache) Keys() []string { + c.mu.RLock() + defer c.mu.RUnlock() + result := make([]string, 0, len(c.items)) + for k := range c.items { + result = append(result, k) + } + return result +} + +func (c *InMemoryTokenCache) Get(key string) *TokenDetails { + c.mu.RLock() + defer c.mu.RUnlock() + item, found := c.items[key] + if !found { + return nil + } + return item +} + +func (c *InMemoryTokenCache) Put(key string, val *TokenDetails) { + c.mu.Lock() + defer c.mu.Unlock() + c.items[key] = val +} + +func (c *InMemoryTokenCache) Delete(key string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.items, key) +} + +func (c *InMemoryTokenCache) ClearAll() { + c.mu.Lock() + defer c.mu.Unlock() + c.items = make(map[string]*TokenDetails) +} + +// MultiSourceProvider is a caching token provider for multiple datacenters and clients +type MultiSourceProvider struct { + tokenIssuers map[string]*oauth2Issuer + + rescheduleSignal chan struct{} + minRefreshPeriod time.Duration + httpClient *http.Client + logger log.FieldLogger + promMetrics *metrics.PrometheusMetrics + + cache TokenCache + customHeaders map[string]string + sfGroup singleflight.Group + + nextRefreshMu sync.RWMutex + nextRefresh time.Time +} + +// NewMultiSourceProviderWithOpts returns a new instance of MultiSourceProvider with custom settings +func NewMultiSourceProviderWithOpts( + httpClient *http.Client, opts ProviderOpts, sources ...Source, +) *MultiSourceProvider { + p := MultiSourceProvider{} + + if opts.Logger == nil { + opts.Logger = log.NewDisabledLogger() + } + + if opts.MinRefreshPeriod == 0 { + opts.MinRefreshPeriod = defaultMinRefreshPeriod + } + + p.init(httpClient, opts, sources...) + return &p +} + +// NewMultiSourceProvider returns a new instance of MultiSourceProvider with default settings +func NewMultiSourceProvider(httpClient *http.Client) *MultiSourceProvider { + return NewMultiSourceProviderWithOpts(httpClient, ProviderOpts{}) +} + +// RegisterSource allows registering a new Source into MultiSourceProvider +func (p *MultiSourceProvider) RegisterSource(source Source) { + key := keyForIssuer(source.ClientID, source.URL) + if iss, found := p.tokenIssuers[key]; found { + if iss.clientSecret != source.ClientSecret { + iss.clientSecret = source.ClientSecret + p.cache.ClearAll() + } + } + newIssuer := p.newOAuth2Issuer(source.URL, source.ClientID, source.ClientSecret) + p.tokenIssuers[keyForIssuer(source.ClientID, source.URL)] = newIssuer +} + +// GetToken returns raw token for `clientID`, `sourceURL` and `scope` +func (p *MultiSourceProvider) GetToken( + ctx context.Context, clientID, sourceURL string, scope ...string, +) (string, error) { + return p.GetTokenWithHeaders(ctx, clientID, sourceURL, nil, scope...) +} + +// GetTokenWithHeaders returns raw token for `clientID`, `sourceURL` and `scope` while using `headers` +func (p *MultiSourceProvider) GetTokenWithHeaders( + ctx context.Context, clientID, sourceURL string, headers map[string]string, scope ...string, +) (string, error) { + return p.ensureToken(ctx, clientID, sourceURL, headers, scope) +} + +// Invalidate fully invalidates all tokens cache +func (p *MultiSourceProvider) Invalidate() { + p.cache.ClearAll() + p.setNextRefreshSafe(zeroTime) +} + +// RefreshTokensPeriodically starts a goroutine which refreshes tokens +func (p *MultiSourceProvider) RefreshTokensPeriodically(ctx context.Context) { + p.refreshLoop(ctx) +} + +func (p *MultiSourceProvider) issueToken( + ctx context.Context, clientID, sourceURL string, customHeaders map[string]string, scope []string, +) (TokenData, error) { + issuer, found := p.tokenIssuers[keyForIssuer(clientID, sourceURL)] + + if !found { + return TokenData{}, ErrSourceNotRegistered + } + + headers := make(map[string]string) + for k := range p.customHeaders { + headers[k] = p.customHeaders[k] + } + for k := range customHeaders { + headers[k] = customHeaders[k] + } + + _, errEns, _ := p.sfGroup.Do(keyForIssuer(clientID, sourceURL), func() (interface{}, error) { + return nil, issuer.EnsureIssuerURL(ctx, headers) + }) + + if errEns != nil { + p.logger.Error(fmt.Sprintf("(%s, %s): ensure issuer URL", sourceURL, clientID), log.Error(errEns)) + return TokenData{}, errEns + } + + sortedScope := uniqAndSort(scope) + key := keyForCache(clientID, issuer.loadIssuerURL(), sortedScope) + + token, err, _ := p.sfGroup.Do(key, func() (interface{}, error) { + result, issErr := issuer.IssueToken(ctx, headers, sortedScope) + p.cacheToken(result, sourceURL) + return result, issErr + }) + + if err != nil { + p.logger.Error(fmt.Sprintf("(%s, %s): issuing token", issuer.loadIssuerURL(), clientID), log.Error(err)) + return TokenData{}, err + } + + return token.(TokenData), nil +} + +func (p *MultiSourceProvider) ensureToken( + ctx context.Context, clientID, sourceURL string, customHeaders map[string]string, scope []string, +) (string, error) { + token, err := p.getCachedOrInvalidate(clientID, sourceURL, scope) + if err == nil { + return token.Data, nil + } + p.logger.Infof("(%s, %s): could not get token from cache: %v", sourceURL, clientID, err.Error()) + + token, err = p.issueToken(ctx, clientID, sourceURL, customHeaders, scope) + + if err != nil { + return "", err + } + return token.Data, nil +} + +func (p *MultiSourceProvider) cacheToken(token TokenData, sourceURL string) { + issued := time.Now().UTC() + randInt, err := rand.Int(rand.Reader, big.NewInt(expiryDeltaMaxOffset)) + if err != nil { + p.logger.Error("rand init error", log.Error(err)) + return + } + deltaMinutes := time.Minute * time.Duration(randInt.Int64()) + realExpiration := token.Expires.Sub(issued) + refreshDuration := token.Expires.Sub(issued) - defaultExpirationOffset - deltaMinutes + if realExpiration < defaultExpirationOffset { + refreshDuration = realExpiration / 5 + } + + nextRefresh := issued.Add(refreshDuration) + invalidation := issued.Add(realExpiration) + details := &TokenDetails{ + token: token, + issued: issued, + nextRefresh: nextRefresh, + invalidation: invalidation, + sourceURL: sourceURL, + } + + key := keyForCache(token.ClientID, token.issueURL, uniqAndSort(token.Scope)) + p.cache.Put(key, details) + pNextRefresh := p.getNextRefreshSafe() + if pNextRefresh == zeroTime || nextRefresh.UnixNano() <= pNextRefresh.UnixNano() { + p.setNextRefreshSafe(nextRefresh) + select { + case p.rescheduleSignal <- struct{}{}: + default: + } + } +} + +func (p *MultiSourceProvider) getCachedOrInvalidate(clientID, sourceURL string, scope []string) (TokenData, error) { + now := time.Now().UnixNano() + issuer, found := p.tokenIssuers[keyForIssuer(clientID, sourceURL)] + if !found { + return TokenData{}, fmt.Errorf("(%s, %s): not registered", sourceURL, clientID) + } + if issuer.loadIssuerURL() == "" { + return TokenData{}, fmt.Errorf("(%s, %s): issuer URL not acquired", sourceURL, clientID) + } + + key := keyForCache(clientID, issuer.loadIssuerURL(), uniqAndSort(scope)) + details := p.cache.Get(key) + if details == nil { + return TokenData{}, errors.New("token not found in cache") + } + if details.token.Expires.UnixNano() < now { + p.cache.Delete(key) + return TokenData{}, errors.New("token is expired") + } + if details.invalidation.UnixNano() < now { + p.cache.Delete(key) + return TokenData{}, errors.New("token needs to be refreshed") + } + if details.issued.UnixNano() > now { + p.cache.Delete(key) + return TokenData{}, errors.New("token's issued time is invalid") + } + return details.token, nil +} + +func (p *MultiSourceProvider) setNextRefreshSafe(nextRefresh time.Time) { + p.nextRefreshMu.Lock() + p.nextRefresh = nextRefresh + p.nextRefreshMu.Unlock() +} + +func (p *MultiSourceProvider) getNextRefreshSafe() time.Time { + p.nextRefreshMu.RLock() + nextRefresh := p.nextRefresh + p.nextRefreshMu.RUnlock() + return nextRefresh +} + +func (p *MultiSourceProvider) refreshTokens(ctx context.Context) { + now := time.Now().UTC() + + resultMap := make(map[*TokenDetails]struct{}) + nextRefresh := zeroTime + for _, key := range p.cache.Keys() { + details := p.cache.Get(key) + if details == nil { + continue + } + if details.nextRefresh.UnixNano() <= now.UnixNano() { + resultMap[details] = struct{}{} + continue + } + if nextRefresh == zeroTime { + nextRefresh = details.nextRefresh + } + if details.nextRefresh.UnixNano() <= nextRefresh.UnixNano() { + nextRefresh = details.nextRefresh + } + } + p.setNextRefreshSafe(nextRefresh) + toRefresh := make([]*TokenDetails, 0, len(resultMap)) + for token := range resultMap { + toRefresh = append(toRefresh, token) + } + + for _, details := range toRefresh { + _, err := p.issueToken(ctx, details.token.ClientID, details.sourceURL, nil, details.requestedScope) + if err != nil { + p.setNextRefreshSafe(now) + p.logger.Error( + fmt.Sprintf("(%s, %s): refresh error", details.sourceURL, details.token.ClientID), log.Error(err), + ) + } + } +} + +func (p *MultiSourceProvider) refreshLoop(ctx context.Context) { + t := time.NewTimer(time.Hour) + if !t.Stop() { + <-t.C + } + stopped := true + lastRefresh := time.Now().UTC() + currentRefresh := zeroTime + scheduleNext := func() { + nextRefresh := p.getNextRefreshSafe() + + currentRefresh = nextRefresh + if nextRefresh == zeroTime { + stopped = true + return + } + + now := time.Now().UTC() + next := nextRefresh.Sub(now) + if nextRefresh.Sub(lastRefresh) < p.minRefreshPeriod { + next = lastRefresh.Add(p.minRefreshPeriod).Sub(now) + } + + stopped = false + t.Reset(next) + } + scheduleNext() + for { + select { + case <-t.C: + lastRefresh = time.Now().UTC() + p.refreshTokens(ctx) + scheduleNext() + case <-p.rescheduleSignal: + nextRefresh := p.getNextRefreshSafe() + + if currentRefresh != nextRefresh { + if !stopped && !t.Stop() { + <-t.C + } + + if stopped { + // Token was issued a moment ago. + lastRefresh = time.Now().UTC() + } + + scheduleNext() + } + case <-ctx.Done(): + if !stopped && !t.Stop() { + <-t.C + } + return + } + } +} + +func (p *MultiSourceProvider) init(httpClient *http.Client, opts ProviderOpts, sources ...Source) { + if httpClient == nil { + panic("httpClient is mandatory") + } + p.cache = opts.CustomCacheInstance + if p.cache == nil { + p.cache = NewInMemoryTokenCache() + } + p.rescheduleSignal = make(chan struct{}, 1) + p.nextRefresh = zeroTime + p.minRefreshPeriod = opts.MinRefreshPeriod + p.logger = opts.Logger + p.httpClient = httpClient + p.tokenIssuers = make(map[string]*oauth2Issuer) + p.promMetrics = metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "token_provider") + p.customHeaders = opts.CustomHeaders + + for _, source := range sources { + p.RegisterSource(source) + } +} + +// Provider is a caching token provider for a single credentials set +type Provider struct { + provider *MultiSourceProvider + source Source +} + +// NewProvider returns a new instance of Provider +func NewProvider(httpClient *http.Client, source Source) *Provider { + return NewProviderWithOpts(httpClient, ProviderOpts{}, source) +} + +// NewProviderWithOpts returns a new instance of Provider with custom options +func NewProviderWithOpts(httpClient *http.Client, opts ProviderOpts, source Source) *Provider { + mp := Provider{ + source: source, + provider: NewMultiSourceProviderWithOpts(httpClient, opts, source), + } + return &mp +} + +// RefreshTokensPeriodically starts a goroutine which refreshes tokens +func (mp *Provider) RefreshTokensPeriodically(ctx context.Context) { + mp.provider.RefreshTokensPeriodically(ctx) +} + +// GetToken returns raw token for `scope` +func (mp *Provider) GetToken( + ctx context.Context, scope ...string, +) (string, error) { + return mp.provider.GetToken(ctx, mp.source.ClientID, mp.source.URL, scope...) +} + +// GetTokenWithHeaders returns raw token for `scope` while using `headers` +func (mp *Provider) GetTokenWithHeaders( + ctx context.Context, headers map[string]string, scope ...string, +) (string, error) { + return mp.provider.GetTokenWithHeaders(ctx, mp.source.ClientID, mp.source.URL, headers, scope...) +} + +func (mp *Provider) Invalidate() { + mp.provider.Invalidate() +} + +type oauth2Issuer struct { + baseURL string + clientID string + clientSecret string + httpClient *http.Client + logger log.FieldLogger + issuerURL atomic.Value + promMetrics *metrics.PrometheusMetrics +} + +func (p *MultiSourceProvider) newOAuth2Issuer(baseURL, clientID, clientSecret string) *oauth2Issuer { + return &oauth2Issuer{ + baseURL: baseURL, + clientID: clientID, + clientSecret: clientSecret, + httpClient: p.httpClient, + logger: p.logger, + promMetrics: p.promMetrics, + } +} + +func (ti *oauth2Issuer) loadIssuerURL() string { + if v := ti.issuerURL.Load(); v != nil { + return v.(string) + } + return "" +} + +func (ti *oauth2Issuer) EnsureIssuerURL(ctx context.Context, customHeaders map[string]string) error { + if ti.loadIssuerURL() != "" { + return nil + } + + openIDCfgURL := strings.TrimSuffix(ti.baseURL, "/") + wellKnownPath + openIDCfg, err := idputil.GetOpenIDConfiguration( + ctx, ti.httpClient, openIDCfgURL, customHeaders, ti.logger, ti.promMetrics) + if err != nil { + return fmt.Errorf("(%s, %s): get OpenID configuration: %w", ti.baseURL, ti.clientID, err) + } + + if _, err = url.ParseRequestURI(openIDCfg.TokenURL); err != nil { + return fmt.Errorf("(%s, %s): issuer have returned a non-valid URL %q: %w", + ti.baseURL, ti.clientID, openIDCfg.TokenURL, err) + } + ti.issuerURL.Store(openIDCfg.TokenURL) + return nil +} + +func (ti *oauth2Issuer) IssueToken( + ctx context.Context, customHeaders map[string]string, scope []string, +) (TokenData, error) { + issuerURL := ti.loadIssuerURL() + if issuerURL == "" { + panic("must first ensure issuerURL") + } + values := url.Values{} + values.Add("grant_type", "client_credentials") + scopeStr := strings.Join(scope, " ") + if scopeStr != "" { + values.Add("scope", scopeStr) + } + req, reqErr := http.NewRequest(http.MethodPost, issuerURL, strings.NewReader(values.Encode())) + if reqErr != nil { + return TokenData{}, reqErr + } + req = req.WithContext(ctx) + req.SetBasicAuth(ti.clientID, ti.clientSecret) + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + for key := range customHeaders { + req.Header.Add(key, customHeaders[key]) + } + start := time.Now() + resp, err := ti.httpClient.Do(req) + elapsed := time.Since(start) + + if err != nil { + ti.promMetrics.ObserveHTTPClientRequest(http.MethodPost, issuerURL, 0, elapsed, metrics.HTTPRequestErrorDo) + return TokenData{}, fmt.Errorf("do http request: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + ti.logger.Error( + fmt.Sprintf("(%s, %s): closing body", ti.loadIssuerURL(), ti.clientID), log.Error(err), + ) + } + }() + + tokenResponse := tokenResponseBody{} + if err = json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { + ti.promMetrics.ObserveHTTPClientRequest( + http.MethodPost, issuerURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorDecodeBody) + return TokenData{}, fmt.Errorf( + "(%s, %s): read and unmarshal IDP response: %w", ti.loadIssuerURL(), ti.clientID, err, + ) + } + + if resp.StatusCode != http.StatusOK { + ti.promMetrics.ObserveHTTPClientRequest( + http.MethodPost, issuerURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorUnexpectedStatusCode) + return TokenData{}, &UnexpectedIDPResponseError{HTTPCode: resp.StatusCode, IssueURL: ti.loadIssuerURL()} + } + + ti.promMetrics.ObserveHTTPClientRequest(http.MethodPost, issuerURL, resp.StatusCode, elapsed, "") + expires := time.Now().Add(time.Second * time.Duration(tokenResponse.ExpiresIn)) + ti.logger.Infof("(%s, %s): issued token, expires on %s", ti.loadIssuerURL(), ti.clientID, expires.UTC()) + return TokenData{ + Data: tokenResponse.AccessToken, + Scope: scope, + Expires: expires, + issueURL: ti.loadIssuerURL(), + ClientID: ti.clientID, + }, nil +} + +// ProviderOpts represents options for creating a new MultiSourceProvider +type ProviderOpts struct { + // Logger is a logger for MultiSourceProvider. + Logger log.FieldLogger + + // MinRefreshPeriod is a minimal possible refresh interval for MultiSourceProvider's token cache. + MinRefreshPeriod time.Duration + + // CustomHeaders is a map of custom headers to be used in all HTTP requests. + CustomHeaders map[string]string + + // CustomCacheInstance is a custom token cache instance to be used in MultiSourceProvider. + CustomCacheInstance TokenCache + + // PrometheusLibInstanceLabel is a label for Prometheus metrics. + // It allows distinguishing metrics from different instances of the same service. + PrometheusLibInstanceLabel string +} + +func uniqAndSort(s []string) []string { + uniq := make(map[string]struct{}) + for ix := range s { + uniq[s[ix]] = struct{}{} + } + result := make([]string, 0, len(uniq)) + for k := range uniq { + result = append(result, k) + } + sort.Strings(result) + return result +} + +func keyForCache(clientID, sourceURL string, scope []string) string { + return clientID + ":" + sourceURL + ":" + strings.Join(scope, ",") +} + +func keyForIssuer(clientID, sourceURL string) string { + return sourceURL + ":" + clientID +} + +type tokenResponseBody struct { + AccessToken string `json:"access_token"` + Scope string `json:"scope,omitempty"` + // not empty if token scope is different + // from the requested scope. Is equal to + // serialized token scope claim. Returned + // explicitly so client can know token + // scope w/o token parsing. Useful for + // middleware token response processing + ExpiresIn int `json:"expires_in"` + + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} diff --git a/idptoken/provider_test.go b/idptoken/provider_test.go new file mode 100644 index 0000000..2f1a572 --- /dev/null +++ b/idptoken/provider_test.go @@ -0,0 +1,464 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + "time" + + "github.com/acronis/go-appkit/httpclient" + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-appkit/testutil" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/internal/metrics" + "github.com/acronis/go-authkit/jwt" +) + +const ( + expectedUserAgent = "Token MultiSourceProvider/1.0" + expectedXRequestID = "test" + testClientID = "89cadd1f-8649-4531-8b1d-a25de5aa3cd6" + defaultTestTokenExpirationTime = 2 +) + +type tTokenResponseBody struct { + AccessToken string `json:"access_token"` + Scope string `json:"scope,omitempty"` + ExpiresIn int `json:"expires_in"` + Error string `json:"error"` +} + +type tFailingIDPTokenHandler struct{} + +func (h *tFailingIDPTokenHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusInternalServerError) + response := tTokenResponseBody{ + Error: "server_error", + } + encoder := json.NewEncoder(rw) + err := encoder.Encode(response) + if err != nil { + http.Error(rw, fmt.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } +} + +type tHeaderCheckingIDPTokenHandler struct { + t *testing.T +} + +func (h *tHeaderCheckingIDPTokenHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + require.Equal(h.t, expectedUserAgent, r.Header.Get("User-Agent")) + require.Equal(h.t, expectedXRequestID, r.Header.Get("X-Request-ID")) + rw.WriteHeader(http.StatusOK) + response := tTokenResponseBody{ + AccessToken: "success", + ExpiresIn: defaultTestTokenExpirationTime, + Scope: "tenants:viewer", + } + encoder := json.NewEncoder(rw) + err := encoder.Encode(response) + if err != nil { + http.Error(rw, fmt.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } +} + +func TestProviderWithCache(t *testing.T) { + tr, _ := httpclient.NewRetryableRoundTripperWithOpts( + http.DefaultTransport, httpclient.RetryableRoundTripperOpts{MaxRetryAttempts: 3}, + ) + httpClient := &http.Client{Transport: tr} + logger := log.NewDisabledLogger() + + t.Run("custom headers", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPTokenHandler(&tHeaderCheckingIDPTokenHandler{t}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + CustomHeaders: map[string]string{"User-Agent": expectedUserAgent}, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + _, tokenErr := provider.GetTokenWithHeaders( + context.Background(), testClientID, server.URL(), + map[string]string{"X-Request-ID": expectedXRequestID}, "tenants:read", + ) + require.NoError(t, tokenErr) + }) + + t.Run("get token", func(t *testing.T) { + const tokenTTL = 2 * time.Second + + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: tokenTTL}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + cachedToken, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + + newToken, newTokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, newTokenErr) + require.Equal(t, cachedToken, newToken, "token was not cached") + time.Sleep(tokenTTL * 2) + + reissuedToken, reissuedTokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, reissuedTokenErr) + require.Greater(t, reissuedToken, cachedToken, "token was not re-issued") + }) + + t.Run("automatic refresh", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + + tokenOld, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + time.Sleep(3 * time.Second) + token, refreshErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, refreshErr) + require.Greater(t, token, tokenOld, "token should have already been refreshed") + }) + + t.Run("invalidate", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 10 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 10 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + + tokenOld, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + provider.Invalidate() + time.Sleep(1 * time.Second) + token, refreshErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, refreshErr) + require.Greater(t, token, tokenOld, "token should have already been refreshed") + }) + + t.Run("failing idp endpoint", func(t *testing.T) { + server := idptest.NewHTTPServer(idptest.WithHTTPTokenHandler(&tFailingIDPTokenHandler{})) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL(), + }, + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL() + "/weird", + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + _, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.Error(t, tokenErr) + labels := prometheus.Labels{ + metrics.HTTPClientRequestLabelMethod: http.MethodPost, + metrics.HTTPClientRequestLabelURL: server.URL() + idptest.TokenEndpointPath, + metrics.HTTPClientRequestLabelStatusCode: "500", + metrics.HTTPClientRequestLabelError: "unexpected_status_code", + } + promMetrics := metrics.GetPrometheusMetrics("", "token_provider") + hist := promMetrics.HTTPClientRequestDuration.With(labels).(prometheus.Histogram) + testutil.AssertSamplesCountInHistogram(t, hist, 1) + }) + + t.Run("metrics", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + _, tokenErr := provider.GetToken(context.Background(), testClientID, server.URL(), "tenants:read") + require.NoError(t, tokenErr) + labels := prometheus.Labels{ + metrics.HTTPClientRequestLabelMethod: http.MethodPost, + metrics.HTTPClientRequestLabelURL: server.URL() + idptest.TokenEndpointPath, + metrics.HTTPClientRequestLabelStatusCode: "200", + metrics.HTTPClientRequestLabelError: "", + } + promMetrics := metrics.GetPrometheusMetrics("", "token_provider") + hist := promMetrics.HTTPClientRequestDuration.With(labels).(prometheus.Histogram) + testutil.AssertSamplesCountInHistogram(t, hist, 1) + }) + + t.Run("multiple sources", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + server2 := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + idptest.WithHTTPAddress(":8082"), + ) + require.NoError(t, server2.StartAndWaitForReady(time.Second)) + defer func() { _ = server2.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXS", URL: server.URL(), + }, + { + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXs", URL: server2.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + _, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + + _, tokenErr = provider.GetToken( + context.Background(), testClientID, server2.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + }) + + t.Run("multiple sources", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + server2 := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + idptest.WithHTTPAddress(":8082"), + ) + require.NoError(t, server2.StartAndWaitForReady(time.Second)) + defer func() { _ = server2.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXS", URL: server.URL(), + }, + { + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXs", URL: server2.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials[0]) + go provider.RefreshTokensPeriodically(context.Background()) + provider.RegisterSource(credentials[1]) + _, tokenErr := provider.GetToken( + context.Background(), testClientID, server2.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + }) + + t.Run("single source provider", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := idptoken.Source{ + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXS", URL: server.URL(), + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewProviderWithOpts(httpClient, opts, credentials) + go provider.RefreshTokensPeriodically(context.Background()) + _, tokenErr := provider.GetToken(context.Background(), "tenants:read") + require.NoError(t, tokenErr) + }) + + t.Run("start with no sources and register later", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := idptoken.Source{ + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXS", URL: server.URL(), + } + provider := idptoken.NewMultiSourceProvider(httpClient) + go provider.RefreshTokensPeriodically(context.Background()) + provider.RegisterSource(credentials) + _, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + }) + + t.Run("register source twice", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := idptoken.Source{ + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXS", URL: server.URL(), + } + tokenCache := idptoken.NewInMemoryTokenCache() + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, idptoken.ProviderOpts{CustomCacheInstance: tokenCache}) + go provider.RefreshTokensPeriodically(context.Background()) + provider.RegisterSource(credentials) + credentials.ClientSecret = "newsecret" + provider.RegisterSource(credentials) + _, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + provider.RegisterSource(credentials) + require.Equal(t, 1, len(tokenCache.Keys()), "updating with same secret does not reset the cache") + credentials.ClientSecret = "evennewersecret" + provider.RegisterSource(credentials) + require.Equal(t, 0, len(tokenCache.Keys()), "updating with a new secret does reset the cache") + }) +} + +type claimsProviderWithExpiration struct { + ExpTime time.Duration +} + +func (d *claimsProviderWithExpiration) Provide(_ *http.Request) jwt.Claims { + claims := jwt.Claims{ + // nolint:staticcheck // StandardClaims are used here for test purposes + RegisteredClaims: jwtgo.RegisteredClaims{ + ID: uuid.NewString(), + IssuedAt: jwtgo.NewNumericDate(time.Now().UTC()), + }, + Scope: []jwt.AccessPolicy{ + { + TenantID: "1", + TenantUUID: uuid.NewString(), + Role: "tenant:viewer", + }, + }, + Version: 1, + UserID: "1", + } + + if d.ExpTime <= 0 { + d.ExpTime = 24 * time.Hour + } + claims.ExpiresAt = jwtgo.NewNumericDate(time.Now().UTC().Add(d.ExpTime)) + + return claims +} diff --git a/internal/idputil/doc.go b/internal/idputil/doc.go new file mode 100644 index 0000000..fc03f59 --- /dev/null +++ b/internal/idputil/doc.go @@ -0,0 +1,9 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package idputil provides utilities for working with identity providers. +// It's used in the internal code and not exposed to the public API. +package idputil diff --git a/internal/idputil/openid_configuration.go b/internal/idputil/openid_configuration.go new file mode 100644 index 0000000..370bb08 --- /dev/null +++ b/internal/idputil/openid_configuration.go @@ -0,0 +1,72 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idputil + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/acronis/go-appkit/log" + + "github.com/acronis/go-authkit/internal/metrics" +) + +type OpenIDConfiguration struct { + TokenURL string `json:"token_endpoint"` + IntrospectionEndpoint string `json:"introspection_endpoint"` + JWKSURI string `json:"jwks_uri"` +} + +func GetOpenIDConfiguration( + ctx context.Context, + httpClient *http.Client, + targetURL string, + additionalHeaders map[string]string, + logger log.FieldLogger, + promMetrics *metrics.PrometheusMetrics, +) (OpenIDConfiguration, error) { + req, err := http.NewRequest(http.MethodGet, targetURL, http.NoBody) + if err != nil { + return OpenIDConfiguration{}, fmt.Errorf("new request: %w", err) + } + for key, val := range additionalHeaders { + req.Header.Set(key, val) + } + + startTime := time.Now() + resp, err := httpClient.Do(req.WithContext(ctx)) + elapsed := time.Since(startTime) + if err != nil { + promMetrics.ObserveHTTPClientRequest(http.MethodGet, targetURL, 0, elapsed, metrics.HTTPRequestErrorDo) + return OpenIDConfiguration{}, fmt.Errorf("do request: %w", err) + } + defer func() { + if closeBodyErr := resp.Body.Close(); closeBodyErr != nil && logger != nil { + logger.Error(fmt.Sprintf("closing response body error for GET %s", targetURL), log.Error(closeBodyErr)) + } + }() + + if resp.StatusCode != http.StatusOK { + promMetrics.ObserveHTTPClientRequest( + http.MethodGet, targetURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorUnexpectedStatusCode) + return OpenIDConfiguration{}, fmt.Errorf("unexpected HTTP code %d", resp.StatusCode) + } + + var openIDCfg OpenIDConfiguration + if err = json.NewDecoder(resp.Body).Decode(&openIDCfg); err != nil { + promMetrics.ObserveHTTPClientRequest( + http.MethodGet, targetURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorDecodeBody) + return OpenIDConfiguration{}, fmt.Errorf("decode response body json (Content-Type: %s): %w", + resp.Header.Get("Content-Type"), err) + } + + promMetrics.ObserveHTTPClientRequest(http.MethodGet, targetURL, resp.StatusCode, elapsed, "") + return openIDCfg, nil +} diff --git a/internal/idputil/trusted_issuers_store.go b/internal/idputil/trusted_issuers_store.go new file mode 100644 index 0000000..d7e1fd1 --- /dev/null +++ b/internal/idputil/trusted_issuers_store.go @@ -0,0 +1,80 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idputil + +import ( + "fmt" + "net/url" + "sync" + + "github.com/vasayxtx/go-glob" +) + +type TrustedIssuerURLMatcher func(issURL *url.URL) bool + +type TrustedIssuerStore struct { + mu sync.RWMutex + issuers map[string]string + issuerURLMatchers []TrustedIssuerURLMatcher +} + +func NewTrustedIssuerStore() *TrustedIssuerStore { + return &TrustedIssuerStore{ + issuers: make(map[string]string), + } +} + +func (s *TrustedIssuerStore) AddTrustedIssuer(issName, issURL string) { + s.mu.Lock() + s.issuers[issName] = issURL + s.mu.Unlock() +} + +func (s *TrustedIssuerStore) AddTrustedIssuerURL(issURL string) error { + s.mu.Lock() + defer s.mu.Unlock() + urlMatcher, err := makeTrustedIssuerURLMatcher(issURL) + if err != nil { + return err + } + s.issuerURLMatchers = append(s.issuerURLMatchers, urlMatcher) + return nil +} + +func (s *TrustedIssuerStore) GetURLForIssuer(issuer string) (string, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + if issuerURL, ok := s.issuers[issuer]; ok { + return issuerURL, true + } + + parsedIssURL, err := url.Parse(issuer) + if err != nil { + return "", false + } + for i := range s.issuerURLMatchers { + if s.issuerURLMatchers[i](parsedIssURL) { + return issuer, true + } + } + + return "", false +} + +func makeTrustedIssuerURLMatcher(urlPattern string) (TrustedIssuerURLMatcher, error) { + parsedURL, err := url.Parse(urlPattern) + if err != nil { + return nil, fmt.Errorf("parse issuer URL glob pattern: %w", err) + } + hostMatcher := glob.Compile(parsedURL.Host) + return func(issURL *url.URL) bool { + return hostMatcher(issURL.Host) && + parsedURL.Path == issURL.Path && + parsedURL.Scheme == issURL.Scheme && + parsedURL.RawQuery == issURL.RawQuery + }, nil +} diff --git a/internal/libinfo/doc.go b/internal/libinfo/doc.go new file mode 100644 index 0000000..7087edf --- /dev/null +++ b/internal/libinfo/doc.go @@ -0,0 +1,8 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package libinfo provides helpers for working with the library information. +package libinfo diff --git a/internal/libinfo/lib_info.go b/internal/libinfo/lib_info.go new file mode 100644 index 0000000..b08d088 --- /dev/null +++ b/internal/libinfo/lib_info.go @@ -0,0 +1,26 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package libinfo + +import ( + "fmt" +) + +const libName = "go-authkit" + +const libPath = "github.com/acronis/" + libName + +func MakeUserAgent(prependedUserAgent string) string { + if prependedUserAgent != "" { + prependedUserAgent += " " + } + return prependedUserAgent + " " + libName + "/" + GetLibVersion() +} + +func GetLogPrefix() string { + return fmt.Sprintf("[%s/%s]", libName, GetLibVersion()) +} diff --git a/internal/libinfo/version.go b/internal/libinfo/version.go new file mode 100644 index 0000000..a533b5b --- /dev/null +++ b/internal/libinfo/version.go @@ -0,0 +1,33 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package libinfo + +import ( + "sync" + + "runtime/debug" +) + +var libVersion string +var libVersionOnce sync.Once + +func initLibVersion() { + if buildInfo, ok := debug.ReadBuildInfo(); ok && buildInfo != nil { + for _, dep := range buildInfo.Deps { + if dep.Path == libPath { + libVersion = dep.Version + return + } + } + } + libVersion = "v0.0.0" +} + +func GetLibVersion() string { + libVersionOnce.Do(initLibVersion) + return libVersion +} diff --git a/internal/metrics/doc.go b/internal/metrics/doc.go new file mode 100644 index 0000000..6f9a198 --- /dev/null +++ b/internal/metrics/doc.go @@ -0,0 +1,8 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package metrics provides helpers for working with the library metrics. +package metrics diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go new file mode 100644 index 0000000..e7cf519 --- /dev/null +++ b/internal/metrics/metrics.go @@ -0,0 +1,174 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package metrics + +import ( + "strconv" + "sync" + "time" + + "github.com/acronis/go-appkit/lrucache" + "github.com/prometheus/client_golang/prometheus" + grpccodes "google.golang.org/grpc/codes" + + "github.com/acronis/go-authkit/internal/libinfo" +) + +const PrometheusNamespace = "go_authkit" + +const DefaultPrometheusLibInstanceLabel = "default" + +const ( + PrometheusLibInstanceLabel = "lib_instance" + PrometheusLibSourceLabel = "lib_source" +) + +func PrometheusLabels() prometheus.Labels { + return prometheus.Labels{"lib_version": libinfo.GetLibVersion()} +} + +const ( + HTTPClientRequestLabelMethod = "method" + HTTPClientRequestLabelURL = "url" + HTTPClientRequestLabelStatusCode = "status_code" + HTTPClientRequestLabelError = "error" + + GRPCClientRequestLabelMethod = "grpc_method" + GRPCClientRequestLabelCode = "grpc_code" +) + +const ( + HTTPRequestErrorDo = "do_request_error" + HTTPRequestErrorDecodeBody = "decode_body_error" + HTTPRequestErrorUnexpectedStatusCode = "unexpected_status_code" +) + +var requestDurationBuckets = []float64{0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10} + +var ( + prometheusMetrics *PrometheusMetrics + prometheusMetricsOnce sync.Once +) + +// PrometheusMetrics represents the collector of metrics. +type PrometheusMetrics struct { + HTTPClientRequestDuration *prometheus.HistogramVec + GRPCClientRequestDuration *prometheus.HistogramVec + TokenClaimsCache *lrucache.PrometheusMetrics + TokenNegativeCache *lrucache.PrometheusMetrics +} + +func GetPrometheusMetrics(instance string, source string) *PrometheusMetrics { + prometheusMetricsOnce.Do(func() { + prometheusMetrics = newPrometheusMetrics() + prometheusMetrics.MustRegister() + }) + if instance == "" { + instance = DefaultPrometheusLibInstanceLabel + } + return prometheusMetrics.MustCurryWith(map[string]string{ + PrometheusLibInstanceLabel: instance, + PrometheusLibSourceLabel: source, + }) +} + +func newPrometheusMetrics() *PrometheusMetrics { + curriedLabelNames := []string{PrometheusLibInstanceLabel, PrometheusLibSourceLabel} + makeLabelNames := func(names ...string) []string { + l := append(make([]string, 0, len(curriedLabelNames)+len(names)), curriedLabelNames...) + return append(l, names...) + } + + httpClientReqDuration := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: PrometheusNamespace, + Name: "http_client_request_duration_seconds", + Help: "A histogram of the http client request durations to IDP endpoints.", + Buckets: requestDurationBuckets, + ConstLabels: PrometheusLabels(), + }, + makeLabelNames(HTTPClientRequestLabelMethod, HTTPClientRequestLabelURL, + HTTPClientRequestLabelStatusCode, HTTPClientRequestLabelError), + ) + grpcClientReqDuration := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: PrometheusNamespace, + Name: "grpc_client_request_duration_seconds", + Help: "A histogram of the grpc client request durations to IDP endpoints.", + Buckets: requestDurationBuckets, + ConstLabels: PrometheusLabels(), + }, + makeLabelNames(GRPCClientRequestLabelMethod, GRPCClientRequestLabelCode), + ) + + tokenClaimsCache := lrucache.NewPrometheusMetricsWithOpts(lrucache.PrometheusMetricsOpts{ + Namespace: PrometheusNamespace + "_token_claims", + ConstLabels: PrometheusLabels(), + CurriedLabelNames: curriedLabelNames, + }) + + tokenNegativeCache := lrucache.NewPrometheusMetricsWithOpts(lrucache.PrometheusMetricsOpts{ + Namespace: PrometheusNamespace + "_token_negative", + ConstLabels: PrometheusLabels(), + CurriedLabelNames: curriedLabelNames, + }) + + return &PrometheusMetrics{ + HTTPClientRequestDuration: httpClientReqDuration, + GRPCClientRequestDuration: grpcClientReqDuration, + TokenClaimsCache: tokenClaimsCache, + TokenNegativeCache: tokenNegativeCache, + } +} + +// MustCurryWith curries the metrics collector with the provided labels. +func (pm *PrometheusMetrics) MustCurryWith(labels prometheus.Labels) *PrometheusMetrics { + return &PrometheusMetrics{ + HTTPClientRequestDuration: pm.HTTPClientRequestDuration.MustCurryWith(labels).(*prometheus.HistogramVec), + GRPCClientRequestDuration: pm.GRPCClientRequestDuration.MustCurryWith(labels).(*prometheus.HistogramVec), + TokenClaimsCache: pm.TokenClaimsCache.MustCurryWith(labels), + TokenNegativeCache: pm.TokenNegativeCache.MustCurryWith(labels), + } +} + +// MustRegister does registration of metrics collector in Prometheus and panics if any error occurs. +func (pm *PrometheusMetrics) MustRegister() { + prometheus.MustRegister( + pm.HTTPClientRequestDuration, + pm.GRPCClientRequestDuration, + ) + pm.TokenClaimsCache.MustRegister() + pm.TokenNegativeCache.MustRegister() +} + +// Unregister cancels registration of metrics collector in Prometheus. +func (pm *PrometheusMetrics) Unregister() { + prometheus.Unregister(pm.HTTPClientRequestDuration) + prometheus.Unregister(pm.GRPCClientRequestDuration) + pm.TokenClaimsCache.Unregister() + pm.TokenNegativeCache.Unregister() +} + +func (pm *PrometheusMetrics) ObserveHTTPClientRequest( + method string, targetURL string, statusCode int, elapsed time.Duration, errorType string, +) { + pm.HTTPClientRequestDuration.With(prometheus.Labels{ + HTTPClientRequestLabelMethod: method, + HTTPClientRequestLabelURL: targetURL, + HTTPClientRequestLabelStatusCode: strconv.Itoa(statusCode), + HTTPClientRequestLabelError: errorType, + }).Observe(elapsed.Seconds()) +} + +func (pm *PrometheusMetrics) ObserveGRPCClientRequest( + method string, code grpccodes.Code, elapsed time.Duration, +) { + pm.GRPCClientRequestDuration.With(prometheus.Labels{ + GRPCClientRequestLabelMethod: method, + GRPCClientRequestLabelCode: code.String(), + }).Observe(elapsed.Seconds()) +} diff --git a/internal/testing/doc.go b/internal/testing/doc.go new file mode 100644 index 0000000..a516b74 --- /dev/null +++ b/internal/testing/doc.go @@ -0,0 +1,8 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package testing provides internal testing utilities. +package testing diff --git a/internal/testing/server_token_introspector_mock.go b/internal/testing/server_token_introspector_mock.go new file mode 100644 index 0000000..9f4b247 --- /dev/null +++ b/internal/testing/server_token_introspector_mock.go @@ -0,0 +1,152 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package testing + +import ( + "context" + "crypto/sha256" + "net/http" + "net/url" + + "google.golang.org/grpc/metadata" + + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/idptoken/pb" + "github.com/acronis/go-authkit/jwt" +) + +type JWTParser interface { + Parse(ctx context.Context, token string) (*jwt.Claims, error) +} + +type HTTPServerTokenIntrospectorMock struct { + JWTParser JWTParser + + introspectionResults map[[sha256.Size]byte]idptoken.IntrospectionResult + jwtScopes map[string][]jwt.AccessPolicy + + Called bool + LastAuthorizationHeader string + LastIntrospectedToken string + LastFormValues url.Values +} + +func NewHTTPServerTokenIntrospectorMock() *HTTPServerTokenIntrospectorMock { + return &HTTPServerTokenIntrospectorMock{ + introspectionResults: make(map[[sha256.Size]byte]idptoken.IntrospectionResult), + jwtScopes: make(map[string][]jwt.AccessPolicy), + } +} + +func (m *HTTPServerTokenIntrospectorMock) SetResultForToken(token string, result idptoken.IntrospectionResult) { + m.introspectionResults[tokenToKey(token)] = result +} + +func (m *HTTPServerTokenIntrospectorMock) SetScopeForJWTID(jwtID string, scope []jwt.AccessPolicy) { + m.jwtScopes[jwtID] = scope +} + +func (m *HTTPServerTokenIntrospectorMock) IntrospectToken(r *http.Request, token string) idptoken.IntrospectionResult { + m.Called = true + m.LastAuthorizationHeader = r.Header.Get("Authorization") + m.LastIntrospectedToken = token + m.LastFormValues = r.Form + + if result, ok := m.introspectionResults[tokenToKey(token)]; ok { + return result + } + + claims, err := m.JWTParser.Parse(r.Context(), token) + if err != nil { + return idptoken.IntrospectionResult{Active: false} + } + result := idptoken.IntrospectionResult{Active: true, TokenType: idptoken.TokenTypeBearer, Claims: *claims} + if scopes, ok := m.jwtScopes[claims.ID]; ok { + result.Scope = scopes + } + return result +} + +func (m *HTTPServerTokenIntrospectorMock) ResetCallsInfo() { + m.Called = false + m.LastAuthorizationHeader = "" + m.LastIntrospectedToken = "" + m.LastFormValues = nil +} + +type GRPCServerTokenIntrospectorMock struct { + JWTParser JWTParser + + introspectionResults map[[sha256.Size]byte]*pb.IntrospectTokenResponse + scopes map[string][]*pb.AccessTokenScope + + Called bool + LastAuthorizationMeta string + LastRequest *pb.IntrospectTokenRequest +} + +func NewGRPCServerTokenIntrospectorMock() *GRPCServerTokenIntrospectorMock { + return &GRPCServerTokenIntrospectorMock{ + introspectionResults: make(map[[sha256.Size]byte]*pb.IntrospectTokenResponse), + scopes: make(map[string][]*pb.AccessTokenScope), + } +} + +func (m *GRPCServerTokenIntrospectorMock) SetResultForToken(token string, result *pb.IntrospectTokenResponse) { + m.introspectionResults[tokenToKey(token)] = result +} + +func (m *GRPCServerTokenIntrospectorMock) SetScopeForJWTID(jwtID string, scope []*pb.AccessTokenScope) { + m.scopes[jwtID] = scope +} + +func (m *GRPCServerTokenIntrospectorMock) IntrospectToken( + ctx context.Context, req *pb.IntrospectTokenRequest, +) (*pb.IntrospectTokenResponse, error) { + m.Called = true + if mdVal := metadata.ValueFromIncomingContext(ctx, "authorization"); len(mdVal) != 0 { + m.LastAuthorizationMeta = mdVal[0] + } else { + m.LastAuthorizationMeta = "" + } + m.LastRequest = req + + if result, ok := m.introspectionResults[tokenToKey(req.Token)]; ok { + return result, nil + } + + claims, err := m.JWTParser.Parse(ctx, req.Token) + if err != nil { + return &pb.IntrospectTokenResponse{Active: false}, nil + } + result := &pb.IntrospectTokenResponse{ + Active: true, + TokenType: idptoken.TokenTypeBearer, + Exp: claims.ExpiresAt.Unix(), + Aud: claims.Audience, + Jti: claims.ID, + Iss: claims.Issuer, + Sub: claims.Subject, + SubType: claims.SubType, + ClientId: claims.ClientID, + OwnerTenantUuid: claims.OwnerTenantUUID, + } + if scopes, ok := m.scopes[claims.ID]; ok { + result.Scope = scopes + } + return result, nil +} + +func (m *GRPCServerTokenIntrospectorMock) ResetCallsInfo() { + m.Called = false + m.LastAuthorizationMeta = "" + m.LastRequest = nil +} + +func tokenToKey(token string) [sha256.Size]byte { + return sha256.Sum256([]byte(token)) +} diff --git a/jwks/caching_client.go b/jwks/caching_client.go new file mode 100644 index 0000000..6954618 --- /dev/null +++ b/jwks/caching_client.go @@ -0,0 +1,169 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwks + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-appkit/lrucache" +) + +const DefaultCacheUpdateMinInterval = time.Minute * 1 + +// CachingClientOpts contains options for CachingClient. +type CachingClientOpts struct { + ClientOpts + + // CacheUpdateMinInterval is a minimal interval between cache updates for the same issuer. + CacheUpdateMinInterval time.Duration +} + +// CachingClient is a Client for getting keys from remote JWKS with a caching mechanism. +type CachingClient struct { + mu sync.RWMutex + rawClient *Client + issuerCache map[string]issuerCacheEntry + cacheUpdateMinInterval time.Duration +} + +const missingKeysCacheSize = 100 + +type issuerCacheEntry struct { + updatedAt time.Time + keys map[string]interface{} + missingKeys *lrucache.LRUCache[string, time.Time] +} + +// NewCachingClient returns a new Client that can cache fetched data. +func NewCachingClient(httpClient *http.Client, logger log.FieldLogger) *CachingClient { + return NewCachingClientWithOpts(httpClient, logger, CachingClientOpts{}) +} + +// NewCachingClientWithOpts returns a new Client that can cache fetched data with options. +func NewCachingClientWithOpts(httpClient *http.Client, logger log.FieldLogger, opts CachingClientOpts) *CachingClient { + if opts.CacheUpdateMinInterval == 0 { + opts.CacheUpdateMinInterval = DefaultCacheUpdateMinInterval + } + return &CachingClient{ + rawClient: NewClientWithOpts(httpClient, logger, opts.ClientOpts), + issuerCache: make(map[string]issuerCacheEntry), + cacheUpdateMinInterval: opts.CacheUpdateMinInterval, + } +} + +// GetRSAPublicKey searches JWK with passed key ID in JWKS and returns decoded RSA public key for it. +// The last one can be used for verifying JWT signature. Obtained JWKS is cached. +// If passed issuer URL or key ID is not found in the cache, JWKS will be fetched again, +// but not more than once in a some (configurable) period of time. +func (cc *CachingClient) GetRSAPublicKey(ctx context.Context, issuerURL, keyID string) (interface{}, error) { + pubKey, found, needInvalidate := cc.getPubKeyFromCache(issuerURL, keyID) + if found { + return pubKey, nil + } + if needInvalidate { + var err error + if pubKey, found, err = cc.getPubKeyFromCacheAndInvalidate(ctx, issuerURL, keyID); err != nil || found { + return pubKey, err + } + } + return nil, &JWKNotFoundError{IssuerURL: issuerURL, KeyID: keyID} +} + +// InvalidateCacheIfNeeded does cache invalidation for specific issuer URL if it's necessary. +func (cc *CachingClient) InvalidateCacheIfNeeded(ctx context.Context, issuerURL string) error { + cc.mu.Lock() + defer cc.mu.Unlock() + + var missingKeys *lrucache.LRUCache[string, time.Time] + issCache, found := cc.issuerCache[issuerURL] + if found { + if time.Since(issCache.updatedAt) < cc.cacheUpdateMinInterval { + return nil + } + missingKeys = issCache.missingKeys + } else { + var err error + if missingKeys, err = lrucache.New[string, time.Time](missingKeysCacheSize, nil); err != nil { + return fmt.Errorf("new lru cache for missing keys: %w", err) + } + } + + pubKeys, err := cc.rawClient.getRSAPubKeysForIssuer(ctx, issuerURL) + if err != nil { + return fmt.Errorf("get rsa public keys for issuer %q: %w", issuerURL, err) + } + cc.issuerCache[issuerURL] = issuerCacheEntry{ + updatedAt: time.Now(), + keys: pubKeys, + missingKeys: missingKeys, + } + return nil +} + +func (cc *CachingClient) getPubKeyFromCache( + issuerURL, keyID string, +) (pubKey interface{}, found bool, needInvalidate bool) { + cc.mu.RLock() + defer cc.mu.RUnlock() + + issCache, issFound := cc.issuerCache[issuerURL] + if !issFound { + return nil, false, true + } + if pubKey, found = issCache.keys[keyID]; found { + return + } + missedAt, miss := issCache.missingKeys.Get(keyID) + if !miss || time.Since(missedAt) > cc.cacheUpdateMinInterval { + return nil, false, true + } + return nil, false, false +} + +func (cc *CachingClient) getPubKeyFromCacheAndInvalidate( + ctx context.Context, issuerURL, keyID string, +) (pubKey interface{}, found bool, err error) { + cc.mu.Lock() + defer cc.mu.Unlock() + + var missingKeys *lrucache.LRUCache[string, time.Time] + if issCache, issFound := cc.issuerCache[issuerURL]; issFound { + if pubKey, found = issCache.keys[keyID]; found { + return pubKey, true, nil + } + missedAt, miss := issCache.missingKeys.Get(keyID) + if miss && time.Since(missedAt) < cc.cacheUpdateMinInterval { + return nil, false, nil + } + missingKeys = issCache.missingKeys + } else { + missingKeys, err = lrucache.New[string, time.Time](missingKeysCacheSize, nil) + if err != nil { + return nil, false, fmt.Errorf("new lru cache for missing keys: %w", err) + } + } + + pubKeys, err := cc.rawClient.getRSAPubKeysForIssuer(ctx, issuerURL) + if err != nil { + return nil, false, fmt.Errorf("get rsa public keys for issuer %q: %w", issuerURL, err) + } + pubKey, found = pubKeys[keyID] + if !found { + missingKeys.Add(keyID, time.Now()) + } + cc.issuerCache[issuerURL] = issuerCacheEntry{ + updatedAt: time.Now(), + keys: pubKeys, + missingKeys: missingKeys, + } + return pubKey, found, nil +} diff --git a/jwks/caching_client_test.go b/jwks/caching_client_test.go new file mode 100644 index 0000000..395bbba --- /dev/null +++ b/jwks/caching_client_test.go @@ -0,0 +1,110 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwks_test + +import ( + "context" + "crypto/rsa" + "errors" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/acronis/go-appkit/log" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/jwks" +) + +func TestCachingClient_GetRSAPublicKey(t *testing.T) { + t.Run("ok", func(t *testing.T) { + jwksHandler := &idptest.JWKSHandler{} + jwksServer := httptest.NewServer(jwksHandler) + defer jwksServer.Close() + issuerConfigHandler := &idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL} + issuerConfigServer := httptest.NewServer(issuerConfigHandler) + defer issuerConfigServer.Close() + + cachingClient := jwks.NewCachingClientWithOpts(http.DefaultClient, log.NewDisabledLogger(), + jwks.CachingClientOpts{CacheUpdateMinInterval: time.Second * 10}) + var wg sync.WaitGroup + const callsNum = 10 + wg.Add(callsNum) + errs := make(chan error, callsNum) + pubKeys := make(chan interface{}, callsNum) + for i := 0; i < callsNum; i++ { + go func() { + defer wg.Done() + pubKey, err := cachingClient.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + if err != nil { + errs <- err + return + } + pubKeys <- pubKey + }() + } + wg.Wait() + close(errs) + close(pubKeys) + for err := range errs { + require.NoError(t, err) + } + for pubKey := range pubKeys { + require.NotNil(t, pubKey) + require.IsType(t, &rsa.PublicKey{}, pubKey) + } + require.EqualValues(t, 1, issuerConfigHandler.ServedCount()) + require.EqualValues(t, 1, jwksHandler.ServedCount()) + }) + + t.Run("jwk not found", func(t *testing.T) { + jwksHandler := &idptest.JWKSHandler{} + jwksServer := httptest.NewServer(jwksHandler) + defer jwksServer.Close() + issuerConfigHandler := &idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL} + issuerConfigServer := httptest.NewServer(issuerConfigHandler) + defer issuerConfigServer.Close() + + const unknownKeyID = "77777777-7777-7777-7777-777777777777" + const cacheUpdateMinInterval = time.Second * 1 + + cachingClient := jwks.NewCachingClientWithOpts(http.DefaultClient, log.NewDisabledLogger(), + jwks.CachingClientOpts{CacheUpdateMinInterval: cacheUpdateMinInterval}) + + doGetPublicKeyByUnknownID := func(callsNum int) { + t.Helper() + var wg sync.WaitGroup + wg.Add(callsNum) + for i := 0; i < callsNum; i++ { + go func() { + defer wg.Done() + pubKey, err := cachingClient.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, unknownKeyID) + require.Error(t, err) + var jwkErr *jwks.JWKNotFoundError + require.True(t, errors.As(err, &jwkErr)) + require.Equal(t, issuerConfigServer.URL, jwkErr.IssuerURL) + require.Equal(t, unknownKeyID, jwkErr.KeyID) + require.Nil(t, pubKey) + }() + } + wg.Wait() + } + + doGetPublicKeyByUnknownID(10) + require.EqualValues(t, 1, issuerConfigHandler.ServedCount()) + require.EqualValues(t, 1, jwksHandler.ServedCount()) + + time.Sleep(cacheUpdateMinInterval * 2) + + doGetPublicKeyByUnknownID(10) + require.EqualValues(t, 2, issuerConfigHandler.ServedCount()) + require.EqualValues(t, 2, jwksHandler.ServedCount()) + }) +} diff --git a/jwks/client.go b/jwks/client.go new file mode 100644 index 0000000..9af6075 --- /dev/null +++ b/jwks/client.go @@ -0,0 +1,138 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwks + +import ( + "context" + "crypto" + "crypto/rsa" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/acronis/go-appkit/log" + "github.com/mendsley/gojwk" + + "github.com/acronis/go-authkit/internal/idputil" + "github.com/acronis/go-authkit/internal/metrics" +) + +const OpenIDConfigurationPath = "/.well-known/openid-configuration" + +type jwksData struct { + Keys []*gojwk.Key `json:"keys"` +} + +// ClientOpts contains options for the JWKS client. +type ClientOpts struct { + // PrometheusLibInstanceLabel is a label for Prometheus metrics. + // It allows distinguishing metrics from different instances of the same library. + PrometheusLibInstanceLabel string +} + +// Client gets public keys from remote JWKS. +// It uses jwks_uri field from /.well-known/openid-configuration endpoint. +// NOTE: CachingClient should be used in a typical service +// to avoid making HTTP requests on each JWT verification. +type Client struct { + httpClient *http.Client + logger log.FieldLogger + promMetrics *metrics.PrometheusMetrics +} + +// NewClient returns a new Client. +func NewClient(httpClient *http.Client, logger log.FieldLogger) *Client { + return NewClientWithOpts(httpClient, logger, ClientOpts{}) +} + +// NewClientWithOpts returns a new Client with options. +func NewClientWithOpts(httpClient *http.Client, logger log.FieldLogger, opts ClientOpts) *Client { + promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "jwks_client") + return &Client{httpClient, logger, promMetrics} +} + +func (c *Client) getRSAPubKeysForIssuer(ctx context.Context, issuerURL string) (map[string]interface{}, error) { + openIDConfigURL := strings.TrimPrefix(issuerURL, "/") + OpenIDConfigurationPath + openIDConfig, err := idputil.GetOpenIDConfiguration( + ctx, c.httpClient, openIDConfigURL, nil, c.logger, c.promMetrics) + if err != nil { + return nil, &GetOpenIDConfigurationError{Inner: err, URL: openIDConfigURL} + } + jwksRespData, err := c.getJWKS(ctx, openIDConfig.JWKSURI) + if err != nil { + return nil, &GetJWKSError{Inner: err, URL: openIDConfig.JWKSURI, OpenIDConfigurationURL: openIDConfigURL} + } + c.logger.Info(fmt.Sprintf("%d keys fetched (jwks_url: %s)", len(jwksRespData.Keys), openIDConfig.JWKSURI)) + + pubKeys := make(map[string]interface{}, len(jwksRespData.Keys)) + for _, jwk := range jwksRespData.Keys { + var pubKey crypto.PublicKey + if pubKey, err = jwk.DecodePublicKey(); err != nil { + c.logger.Error(fmt.Sprintf("decoding JWK (kid: %s, jwks_url: %s) to public key error", + jwk.Kid, openIDConfig.JWKSURI), log.Error(err)) + continue + } + rsaPubKey, ok := pubKey.(*rsa.PublicKey) + if !ok { + c.logger.Error(fmt.Sprintf("converting JWK (kid: %s, jwks_url: %s) to RSA public key error", + jwk.Kid, openIDConfig.JWKSURI), log.Error(err)) + continue + } + pubKeys[jwk.Kid] = rsaPubKey + } + return pubKeys, nil +} + +// GetRSAPublicKey gets JWK from JWKS and returns decoded RSA public key. The last one can be used for verifying JWT signature. +func (c *Client) GetRSAPublicKey(ctx context.Context, issuerURL, keyID string) (interface{}, error) { + pubKeys, err := c.getRSAPubKeysForIssuer(ctx, issuerURL) + if err != nil { + return nil, fmt.Errorf("get rsa public keys for issuer %q: %w", issuerURL, err) + } + pubKey, ok := pubKeys[keyID] + if !ok { + return nil, &JWKNotFoundError{IssuerURL: issuerURL, KeyID: keyID} + } + return pubKey, nil +} + +func (c *Client) getJWKS(ctx context.Context, jwksURL string) (jwksData, error) { + req, err := http.NewRequest(http.MethodGet, jwksURL, http.NoBody) + if err != nil { + return jwksData{}, fmt.Errorf("new request: %w", err) + } + startTime := time.Now() + resp, err := c.httpClient.Do(req.WithContext(ctx)) + elapsed := time.Since(startTime) + if err != nil { + c.promMetrics.ObserveHTTPClientRequest(http.MethodGet, jwksURL, 0, elapsed, metrics.HTTPRequestErrorDo) + return jwksData{}, fmt.Errorf("do request: %w", err) + } + defer func() { + if closeBodyErr := resp.Body.Close(); closeBodyErr != nil { + c.logger.Error(fmt.Sprintf("closing response body error for GET %s", jwksURL), log.Error(closeBodyErr)) + } + }() + + if resp.StatusCode != http.StatusOK { + c.promMetrics.ObserveHTTPClientRequest( + http.MethodGet, jwksURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorUnexpectedStatusCode) + return jwksData{}, fmt.Errorf("unexpected HTTP code %d", resp.StatusCode) + } + + var res jwksData + if err = json.NewDecoder(resp.Body).Decode(&res); err != nil { + c.promMetrics.ObserveHTTPClientRequest( + http.MethodGet, jwksURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorDecodeBody) + return jwksData{}, fmt.Errorf("decode response body json: %w", err) + } + + c.promMetrics.ObserveHTTPClientRequest(http.MethodGet, jwksURL, resp.StatusCode, elapsed, "") + return res, nil +} diff --git a/jwks/client_test.go b/jwks/client_test.go new file mode 100644 index 0000000..96e3525 --- /dev/null +++ b/jwks/client_test.go @@ -0,0 +1,160 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwks_test + +import ( + "context" + "crypto/rsa" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/acronis/go-appkit/log" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/jwks" +) + +func TestClient_GetRSAPublicKey(t *testing.T) { + t.Run("ok", func(t *testing.T) { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + require.NoError(t, err) + require.NotNil(t, pubKey) + require.IsType(t, &rsa.PublicKey{}, pubKey) + }) + + t.Run("issuer openid configuration unavailable", func(t *testing.T) { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + issuerConfigServer.Close() // Close the server immediately. + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + require.Error(t, err) + var openIDCfgErr *jwks.GetOpenIDConfigurationError + require.True(t, errors.As(err, &openIDCfgErr)) + require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, openIDCfgErr.URL) + require.ErrorContains(t, openIDCfgErr.Inner, "connection refused") + require.Nil(t, pubKey) + }) + + t.Run("openid configuration server respond internal error", func(t *testing.T) { + issuerConfigServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusInternalServerError) + })) + defer issuerConfigServer.Close() + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + require.Error(t, err) + var openIDCfgErr *jwks.GetOpenIDConfigurationError + require.True(t, errors.As(err, &openIDCfgErr)) + require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, openIDCfgErr.URL) + require.EqualError(t, openIDCfgErr.Inner, "unexpected HTTP code 500") + require.Nil(t, pubKey) + }) + + t.Run("openid configuration server respond invalid json", func(t *testing.T) { + issuerConfigServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + _, err := rw.Write([]byte(`{"invalid-json"]`)) + require.NoError(t, err) + })) + defer issuerConfigServer.Close() + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + require.Error(t, err) + var openIDCfgErr *jwks.GetOpenIDConfigurationError + require.True(t, errors.As(err, &openIDCfgErr)) + require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, openIDCfgErr.URL) + var jsonSyntaxErr *json.SyntaxError + require.True(t, errors.As(openIDCfgErr, &jsonSyntaxErr)) + require.Nil(t, pubKey) + }) + + t.Run("jwks server unavailable", func(t *testing.T) { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + jwksServer.Close() // Close the server immediately. + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + require.Error(t, err) + var jwksErr *jwks.GetJWKSError + require.True(t, errors.As(err, &jwksErr)) + require.Equal(t, jwksServer.URL, jwksErr.URL) + require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, jwksErr.OpenIDConfigurationURL) + require.ErrorContains(t, jwksErr.Inner, "connection refused") + require.Nil(t, pubKey) + }) + + t.Run("jwks server respond internal error", func(t *testing.T) { + jwksServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusInternalServerError) + })) + defer jwksServer.Close() + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + require.Error(t, err) + var jwksErr *jwks.GetJWKSError + require.True(t, errors.As(err, &jwksErr)) + require.Equal(t, jwksServer.URL, jwksErr.URL) + require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, jwksErr.OpenIDConfigurationURL) + require.EqualError(t, jwksErr.Inner, "unexpected HTTP code 500") + require.Nil(t, pubKey) + }) + + t.Run("jwk not found", func(t *testing.T) { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + const unknownKeyID = "77777777-7777-7777-7777-777777777777" + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, unknownKeyID) + require.Error(t, err) + var jwkErr *jwks.JWKNotFoundError + require.True(t, errors.As(err, &jwkErr)) + require.Equal(t, issuerConfigServer.URL, jwkErr.IssuerURL) + require.Equal(t, unknownKeyID, jwkErr.KeyID) + require.Nil(t, pubKey) + }) + + t.Run("context canceled", func(t *testing.T) { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + ctx, cancelCtxFn := context.WithCancel(context.Background()) + cancelCtxFn() // Emulate canceling context. + pubKey, err := client.GetRSAPublicKey(ctx, issuerConfigServer.URL, idptest.TestKeyID) + require.Error(t, err) + var openIDCfgErr *jwks.GetOpenIDConfigurationError + require.True(t, errors.As(err, &openIDCfgErr)) + require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, openIDCfgErr.URL) + require.ErrorIs(t, openIDCfgErr, context.Canceled) + require.Nil(t, pubKey) + }) +} diff --git a/jwks/doc.go b/jwks/doc.go new file mode 100644 index 0000000..56650b7 --- /dev/null +++ b/jwks/doc.go @@ -0,0 +1,8 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package jwks contains clients for getting public keys from JWKS. +package jwks diff --git a/jwks/errors.go b/jwks/errors.go new file mode 100644 index 0000000..bfe9e45 --- /dev/null +++ b/jwks/errors.go @@ -0,0 +1,49 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwks + +import "fmt" + +// GetOpenIDConfigurationError is an error that may occur during getting openID configuration for issuer. +type GetOpenIDConfigurationError struct { + Inner error + URL string +} + +func (e *GetOpenIDConfigurationError) Error() string { + return fmt.Sprintf("error while getting OpenID configuration (URL: %q): %s", e.URL, e.Inner.Error()) +} + +func (e *GetOpenIDConfigurationError) Unwrap() error { + return e.Inner +} + +// GetJWKSError is an error that may occur during getting JWKS. +type GetJWKSError struct { + Inner error + URL string + OpenIDConfigurationURL string +} + +func (e *GetJWKSError) Error() string { + return fmt.Sprintf("error while getting JWKS data (URL: %q, OpenID configuration URL: %q): %s", + e.URL, e.OpenIDConfigurationURL, e.Inner.Error()) +} + +func (e *GetJWKSError) Unwrap() error { + return e.Inner +} + +// JWKNotFoundError is an error that occurs when JWK is not found by kid. +type JWKNotFoundError struct { + IssuerURL string + KeyID string +} + +func (e *JWKNotFoundError) Error() string { + return fmt.Sprintf("JWK not found (Key ID: %q, Issuer URL: %q)", e.KeyID, e.IssuerURL) +} diff --git a/jwt/caching_parser.go b/jwt/caching_parser.go new file mode 100644 index 0000000..f0cfdc0 --- /dev/null +++ b/jwt/caching_parser.go @@ -0,0 +1,113 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwt + +import ( + "context" + "crypto/sha256" + "fmt" + "unsafe" + + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-appkit/lrucache" + jwtgo "github.com/golang-jwt/jwt/v5" + + "github.com/acronis/go-authkit/internal/metrics" +) + +const DefaultClaimsCacheMaxEntries = 1000 + +type CachingParserOpts struct { + ParserOpts + CacheMaxEntries int + CachePrometheusInstanceLabel string +} + +// ClaimsCache is an interface that must be implemented by used cache implementations. +type ClaimsCache interface { + Get(key [sha256.Size]byte) (*Claims, bool) + Add(key [sha256.Size]byte, value *Claims) + Purge() + Len() int +} + +// CachingParser uses the functionality of Parser to parse JWT, but stores resulted Claims objects in the cache. +type CachingParser struct { + *Parser + ClaimsCache ClaimsCache +} + +func NewCachingParser(keysProvider KeysProvider, logger log.FieldLogger) (*CachingParser, error) { + return NewCachingParserWithOpts(keysProvider, logger, CachingParserOpts{}) +} + +func NewCachingParserWithOpts( + keysProvider KeysProvider, logger log.FieldLogger, opts CachingParserOpts, +) (*CachingParser, error) { + promMetrics := metrics.GetPrometheusMetrics(opts.CachePrometheusInstanceLabel, "jwt_parser") + if opts.CacheMaxEntries == 0 { + opts.CacheMaxEntries = DefaultClaimsCacheMaxEntries + } + cache, err := lrucache.New[[sha256.Size]byte, *Claims](opts.CacheMaxEntries, promMetrics.TokenClaimsCache) + if err != nil { + return nil, err + } + return &CachingParser{ + Parser: NewParserWithOpts(keysProvider, logger, opts.ParserOpts), + ClaimsCache: cache, + }, nil +} + +// getTokenHash converts an access token to a string hash that is used as a cache key. +func getTokenHash(token []byte) [sha256.Size]byte { + return sha256.Sum256(token) +} + +// stringToBytesUnsafe converts string to byte slice without memory allocation. (both heap and stack) +func stringToBytesUnsafe(s string) []byte { + // nolint: gosec // memory optimization to prevent redundant slice copying + return unsafe.Slice(unsafe.StringData(s), len(s)) +} + +// Parse calls Parse method of embedded original Parser but stores result into cache. +func (cp *CachingParser) Parse(ctx context.Context, token string) (*Claims, error) { + key := getTokenHash(stringToBytesUnsafe(token)) + cachedClaims, foundInCache, validationErr := cp.getFromCacheAndValidateIfNeeded(key) + if foundInCache { + if validationErr != nil { + return nil, validationErr + } + return cachedClaims, nil + } + claims, err := cp.Parser.Parse(ctx, token) + if err != nil { + return nil, err + } + cp.ClaimsCache.Add(key, claims) + return claims, nil +} + +func (cp *CachingParser) getFromCacheAndValidateIfNeeded(key [sha256.Size]byte) (claims *Claims, found bool, err error) { + cachedClaims, found := cp.ClaimsCache.Get(key) + if !found { + return nil, false, nil + } + if !cp.Parser.skipClaimsValidation { + if err = cp.Parser.claimsValidator.Validate(cachedClaims); err != nil { + return nil, true, fmt.Errorf("%w: %w", jwtgo.ErrTokenInvalidClaims, err) + } + if err = cp.Parser.customValidator(cachedClaims); err != nil { + return nil, true, fmt.Errorf("%w: %w", jwtgo.ErrTokenInvalidClaims, err) + } + } + return cachedClaims, true, nil +} + +// InvalidateClaimsCache removes all preserved parsed Claims objects from cache. +func (cp *CachingParser) InvalidateClaimsCache() { + cp.ClaimsCache.Purge() +} diff --git a/jwt/caching_parser_test.go b/jwt/caching_parser_test.go new file mode 100644 index 0000000..f0ede58 --- /dev/null +++ b/jwt/caching_parser_test.go @@ -0,0 +1,109 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwt_test + +import ( + "context" + "crypto/sha256" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +func getTokenHash(token []byte) [sha256.Size]byte { + tokenCheckSum := sha256.Sum256(token) + return tokenCheckSum +} + +func TestGetTokenHash(t *testing.T) { + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} + tokenString := []byte(idptest.MustMakeTokenStringSignedWithTestKey(claims)) + + th := getTokenHash(tokenString) + require.NotEmpty(t, th, "generated token hash must not be an empty string") + th2 := getTokenHash(tokenString) + require.Equal(t, th, th2, "two hashes of the same token must be equal") + + claims2 := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: "other" + testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(12 * time.Minute))}} + tokenString2 := []byte(idptest.MustMakeTokenStringSignedWithTestKey(claims2)) + th3 := getTokenHash(tokenString2) + require.NotEqual(t, th, th3, "two hashes of different tokens must be different") +} + +func TestCachingParser_Parse(t *testing.T) { + logger := log.NewDisabledLogger() + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} + tokenString := idptest.MustMakeTokenStringSignedWithTestKey(claims) + + parser, err := jwt.NewCachingParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + require.NoError(t, err) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + + var parsedClaims *jwt.Claims + parsedClaims, err = parser.Parse(context.Background(), tokenString) + require.NoError(t, err, "caching parser must not return error from Parse method") + require.Equal(t, claims.Scope, parsedClaims.Scope, "unexpected claims value produced by caching parser") + + require.Equal(t, 1, parser.ClaimsCache.Len(), + "one claims object must be cached after successful parse operation") + + tokenKey := getTokenHash([]byte(tokenString)) + cachedClaims, found := parser.ClaimsCache.Get(tokenKey) + require.True(t, found, "cached claims object must be found by token hash") + require.Equal(t, claims.Scope, cachedClaims.Scope, "unexpected claims value fetched from parser cache") + + parser.InvalidateClaimsCache() + require.Equal(t, 0, parser.ClaimsCache.Len(), + "parser cache must be empty after invalidation") +} + +func TestCachingParser_CheckExpiration(t *testing.T) { + const jwtTTL = 2 * time.Second + + logger := log.NewDisabledLogger() + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(jwtTTL))}} + tokenString := idptest.MustMakeTokenStringSignedWithTestKey(claims) + + parser, err := jwt.NewCachingParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + require.NoError(t, err) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + + var parsedClaims *jwt.Claims + parsedClaims, err = parser.Parse(context.Background(), tokenString) + require.NoError(t, err, "caching parser must not return error from Parse method") + require.Equal(t, claims.Scope, parsedClaims.Scope, "unexpected claims value produced by caching parser") + + require.Equal(t, 1, parser.ClaimsCache.Len(), + "one claims object must be cached after successful parse operation") + + time.Sleep(jwtTTL * 2) + + parsedClaims, err = parser.Parse(context.Background(), tokenString) + require.Error(t, err, "caching parser must return error since cached jwt is expired") + require.Nil(t, parsedClaims) +} diff --git a/jwt/doc.go b/jwt/doc.go new file mode 100644 index 0000000..4280943 --- /dev/null +++ b/jwt/doc.go @@ -0,0 +1,8 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package jwt provides primitives for working with JWT (Parser, Claims, and so on). +package jwt diff --git a/jwt/errors.go b/jwt/errors.go new file mode 100644 index 0000000..fb3d89d --- /dev/null +++ b/jwt/errors.go @@ -0,0 +1,56 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwt + +import ( + "fmt" +) + +// SignAlgUnknownError represents an error when JWT signing algorithm is unknown. +type SignAlgUnknownError struct { + Alg string +} + +func (e *SignAlgUnknownError) Error() string { + return fmt.Sprintf("JWT has unknown signing algorithm %q", e.Alg) +} + +// IssuerUntrustedError represents an error when JWT issuer is untrusted. +type IssuerUntrustedError struct { + Claims *Claims +} + +func (e *IssuerUntrustedError) Error() string { + return fmt.Sprintf("JWT issuer %q untrusted", e.Claims.Issuer) +} + +// IssuerMissingError represents an error when JWT issuer is missing. +type IssuerMissingError struct { + Claims *Claims +} + +func (e *IssuerMissingError) Error() string { + return "JWT issuer missing" +} + +// AudienceMissingError represents an error when JWT audience is missing, but it's required. +type AudienceMissingError struct { + Claims *Claims +} + +func (e *AudienceMissingError) Error() string { + return "JWT audience missing" +} + +// AudienceNotExpectedError represents an error when JWT contains not expected audience. +type AudienceNotExpectedError struct { + Claims *Claims +} + +func (e *AudienceNotExpectedError) Error() string { + return fmt.Sprintf("JWT audience %q not expected", e.Claims.Audience) +} diff --git a/jwt/jwt.go b/jwt/jwt.go new file mode 100644 index 0000000..f1be148 --- /dev/null +++ b/jwt/jwt.go @@ -0,0 +1,255 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwt + +import ( + "context" + "errors" + "fmt" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/vasayxtx/go-glob" + + "github.com/acronis/go-authkit/internal/idputil" +) + +// KeysProvider is an interface for providing keys for verifying JWT. +type KeysProvider interface { + GetRSAPublicKey(ctx context.Context, issuer, keyID string) (interface{}, error) +} + +// CachingKeysProvider is an interface for providing keys for verifying JWT. +// Unlike KeysProvider, it supports caching of obtained keys. +type CachingKeysProvider interface { + KeysProvider + InvalidateCacheIfNeeded(ctx context.Context, issuer string) error +} + +// ParserOpts additional options for parser. +type ParserOpts struct { + SkipClaimsValidation bool + RequireAudience bool + ExpectedAudience []string + TrustedIssuerNotFoundFallback TrustedIssNotFoundFallback +} + +type audienceMatcher func(aud string) bool + +// TrustedIssNotFoundFallback is a function called when given issuer is not found in the list of trusted ones. +// For example, it could be analyzed and then added to the list by calling AddTrustedIssuerURL method. +type TrustedIssNotFoundFallback func(ctx context.Context, p *Parser, iss string) (issURL string, issFound bool) + +// Parser is an object for parsing, validation and verification JWT. +type Parser struct { + parser *jwtgo.Parser + claimsValidator *jwtgo.Validator + customValidator func(claims *Claims) error + skipClaimsValidation bool + keysProvider KeysProvider + + trustedIssuerStore *idputil.TrustedIssuerStore + trustedIssuerNotFoundFallback TrustedIssNotFoundFallback + + logger log.FieldLogger +} + +// NewParser creates new JWT parser with specified keys provider. +func NewParser(keysProvider KeysProvider, logger log.FieldLogger) *Parser { + return NewParserWithOpts(keysProvider, logger, ParserOpts{}) +} + +// NewParserWithOpts creates new JWT parser with specified keys provider and additional options. +func NewParserWithOpts(keysProvider KeysProvider, logger log.FieldLogger, opts ParserOpts) *Parser { + var audienceMatchers []audienceMatcher + for _, audPattern := range opts.ExpectedAudience { + audienceMatchers = append(audienceMatchers, glob.Compile(audPattern)) + } + return &Parser{ + parser: jwtgo.NewParser(jwtgo.WithExpirationRequired()), + claimsValidator: jwtgo.NewValidator(jwtgo.WithExpirationRequired()), + customValidator: makeCustomAudienceValidator(opts.RequireAudience, audienceMatchers), + skipClaimsValidation: opts.SkipClaimsValidation, + keysProvider: keysProvider, + trustedIssuerStore: idputil.NewTrustedIssuerStore(), + trustedIssuerNotFoundFallback: opts.TrustedIssuerNotFoundFallback, + logger: logger, + } +} + +// AddTrustedIssuer adds trusted issuer with specified name and URL. +func (p *Parser) AddTrustedIssuer(issName, issURL string) { + p.trustedIssuerStore.AddTrustedIssuer(issName, issURL) +} + +// AddTrustedIssuerURL adds trusted issuer URL. +func (p *Parser) AddTrustedIssuerURL(issURL string) error { + return p.trustedIssuerStore.AddTrustedIssuerURL(issURL) +} + +// GetURLForIssuer returns URL for issuer if it is trusted. +func (p *Parser) GetURLForIssuer(issuer string) (string, bool) { + return p.trustedIssuerStore.GetURLForIssuer(issuer) +} + +// Parse parses, validates and verifies passed token (it's string representation). Parsed claims is returned. +func (p *Parser) Parse(ctx context.Context, token string) (*Claims, error) { + keyFunc := p.getKeyFunc(ctx) + claims := validatableClaims{customValidator: p.customValidator} + if _, err := p.parser.ParseWithClaims(token, &claims, keyFunc); err != nil { + if !errors.Is(err, jwtgo.ErrTokenSignatureInvalid) { + return nil, err + } + + // If keys provider supports caching, we may try to invalidate it and try parsing JWT again. + cachingKeysProvider, ok := p.keysProvider.(CachingKeysProvider) + if !ok { + return nil, err + } + + issuerURL, issuerURLFound := p.getURLForIssuerWithCallback(ctx, claims.Issuer) + if !issuerURLFound { + return nil, err + } + if err = cachingKeysProvider.InvalidateCacheIfNeeded(ctx, issuerURL); err != nil { + p.logger.Error(fmt.Sprintf("keys provider invalidating cache error for issuer %q", issuerURL), + log.Error(err)) + return nil, err + } + + if _, err = p.parser.ParseWithClaims(token, &claims, keyFunc); err != nil { + return nil, err + } + } + + return &claims.Claims, nil +} + +func (p *Parser) getKeyFunc(ctx context.Context) func(token *jwtgo.Token) (interface{}, error) { + return func(token *jwtgo.Token) (i interface{}, err error) { + switch signAlg := token.Method.Alg(); signAlg { + case "none": //nolint:goconst + return nil, jwtgo.NoneSignatureTypeDisallowedError + + case "RS256", "RS384", "RS512": + // Empty kid is LEGAL, not all IDP impl support kid. + kidStr := "" + if kid, found := token.Header["kid"]; found { + kidStr = kid.(string) + } + claims := token.Claims.(*validatableClaims) + if claims.Issuer == "" { + return nil, &IssuerMissingError{&claims.Claims} + } + issuerURL, issuerURLFound := p.getURLForIssuerWithCallback(ctx, claims.Issuer) + if !issuerURLFound { + return nil, &IssuerUntrustedError{&claims.Claims} + } + return p.keysProvider.GetRSAPublicKey(ctx, issuerURL, kidStr) + + default: + return nil, &SignAlgUnknownError{signAlg} + } + } +} + +func (p *Parser) getURLForIssuerWithCallback(ctx context.Context, issuer string) (string, bool) { + issURL, issFound := p.GetURLForIssuer(issuer) + if issFound { + return issURL, true + } + if p.trustedIssuerNotFoundFallback == nil { + return "", false + } + return p.trustedIssuerNotFoundFallback(ctx, p, issuer) +} + +// Claims represents an extended version of JWT claims. +type Claims struct { + jwtgo.RegisteredClaims + Scope []AccessPolicy `json:"scope,omitempty"` + Version int `json:"ver,omitempty"` + UserID string `json:"uid,omitempty"` + OriginID string `json:"origin,omitempty"` + ClientID string `json:"client_id,omitempty"` + TOTPTime int64 `json:"totp_time,omitempty"` + SubType string `json:"sub_type,omitempty"` + OwnerTenantUUID string `json:"owner_tuid,omitempty"` +} + +// AccessPolicy represents a single access policy. +type AccessPolicy struct { + // TenantID equals to tenant ID for which access is granted (if resource is not specified) + // or which resource is owned by (if resource is specified). + // max length is 36 characters (uuid) + TenantID string `json:"tid,omitempty"` + + // TenantUUID equals to tenant UUID for which access is granted (if resource is not specified) + // or which resource is owned by (if resource is specified). + // max length is 36 characters (uuid) + TenantUUID string `json:"tuid,omitempty"` + + // ResourceServerID must be unique resource server instance or cluster ID. + // max length is 36 characters [a-Z0-9-_] + ResourceServerID string `json:"rs,omitempty"` + + // ResourceNamespace AKA resource type, partitions resources within resource server. + // E.g.: storage, task-manager, account-server, resource-manager, policy-manager etc. + // max length is 36 characters [a-Z0-9-_] + ResourceNamespace string `json:"rn,omitempty"` + + // ResourcePath AKA resource ID AKA resource pointer, is a unique identifier of + // or path to (in scope of resource server and namespace) a single resource or resource collection + // 'path' notion remind that it can contain segments, each meaningfull to resource server + // i.e. each sub-path can correspond to different resources, and access policies can be assigned with any sub-path granularity + // but resource path will be considered as immutable, + // moving resources 'within' the path will break access control logic on both AuthZ server and resource server sides + // e.g: vms, vm1, queues, queue1 + // max length is 255 characters [a-Z0-9-_] + ResourcePath string `json:"rp,omitempty"` + + // Role - role available for the resource specified by resource id + Role string `json:"role,omitempty"` + + AllowPermissions []string `json:"allow,omitempty"` + DenyPermissions []string `json:"deny,omitempty"` +} + +type validatableClaims struct { + Claims + customValidator func(c *Claims) error +} + +func (v *validatableClaims) Validate() error { + if v.customValidator != nil { + return v.customValidator(&v.Claims) + } + return nil +} + +func makeCustomAudienceValidator(requireAudience bool, audienceMatchers []audienceMatcher) func(c *Claims) error { + return func(c *Claims) error { + if len(c.Audience) == 0 { + if requireAudience { + return fmt.Errorf("%w: %w", jwtgo.ErrTokenRequiredClaimMissing, &AudienceMissingError{c}) + } + return nil + } + + if len(audienceMatchers) == 0 { + return nil + } + for i := range audienceMatchers { + for j := range c.Audience { + if audienceMatchers[i](c.Audience[j]) { + return nil + } + } + } + return fmt.Errorf("%w: %w", jwtgo.ErrTokenInvalidAudience, &AudienceNotExpectedError{c}) + } +} diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go new file mode 100644 index 0000000..33c12a8 --- /dev/null +++ b/jwt/jwt_test.go @@ -0,0 +1,353 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwt_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +const testIss = "test-issuer" + +func TestJWTParser_Parse(t *testing.T) { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + logger := log.NewDisabledLogger() + + t.Run("ok", func(t *testing.T) { + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, + TOTPTime: time.Now().Unix(), + SubType: "task_manager", + } + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.NoError(t, err) + require.Equal(t, claims.Scope, parsedClaims.Scope) + require.Equal(t, claims.TOTPTime, parsedClaims.TOTPTime) + require.Equal(t, claims.SubType, parsedClaims.SubType) + }) + + t.Run("ok for empty kid", func(t *testing.T) { + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + } + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.NoError(t, err) + require.Equal(t, claims, parsedClaims) + }) + + t.Run("ok for trusted issuer url (glob pattern)", func(t *testing.T) { + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: issuerConfigServer.URL, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, + } + issURLs := []string{ + issuerConfigServer.URL, + "http://127.0.0.*", + "http://127.*", + } + for _, issURL := range issURLs { + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + require.NoError(t, parser.AddTrustedIssuerURL(issURL)) + parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.NoError(t, err) + require.Equal(t, claims, parsedClaims) + } + }) + + t.Run("ok for expected audience (glob pattern)", func(t *testing.T) { + for _, aud := range []string{"region1.cloud.com", "region2.cloud.com"} { + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Audience: []string{aud}, + Issuer: issuerConfigServer.URL, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, + } + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(http.DefaultClient, logger), logger, jwt.ParserOpts{ + ExpectedAudience: []string{"*.cloud.com"}, + }) + require.NoError(t, parser.AddTrustedIssuerURL(issuerConfigServer.URL)) + parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.NoError(t, err) + require.Equal(t, claims, parsedClaims) + } + }) + + t.Run("malformed jwt", func(t *testing.T) { + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + _, err := parser.Parse(context.Background(), "invalid-jwt") + require.ErrorIs(t, err, jwtgo.ErrTokenMalformed) + require.ErrorContains(t, err, "token contains an invalid number of segments") + }) + + t.Run("unsigned jwt", func(t *testing.T) { + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, + } + token := jwtgo.NewWithClaims(jwtgo.SigningMethodNone, claims) + tokenString, err := token.SignedString(jwtgo.UnsafeAllowNoneSignatureType) + require.NoError(t, err) + + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + _, err = parser.Parse(context.Background(), tokenString) + require.ErrorIs(t, err, jwtgo.NoneSignatureTypeDisallowedError) + }) + + t.Run("jwt issuer missing", func(t *testing.T) { + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{Audience: []string{"https://cloud.acronis.com"}}, + } + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) + var issMissingErr *jwt.IssuerMissingError + require.ErrorAs(t, err, &issMissingErr) + require.Equal(t, claims, issMissingErr.Claims) + }) + + t.Run("jwt has untrusted issuer", func(t *testing.T) { + const issuer = "untrusted-issuer" + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}} + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) + var issUntrustedErr *jwt.IssuerUntrustedError + require.ErrorAs(t, err, &issUntrustedErr) + require.Equal(t, claims, issUntrustedErr.Claims) + }) + + t.Run("jwt has untrusted issuer url", func(t *testing.T) { + const issuer = "https://3rd-party-idp.com" + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}} + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + require.NoError(t, parser.AddTrustedIssuerURL("https://*.acronis.com")) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) + var issUntrustedErr *jwt.IssuerUntrustedError + require.ErrorAs(t, err, &issUntrustedErr) + require.Equal(t, claims, issUntrustedErr.Claims) + }) + + t.Run("jwt has untrusted issuer url, callback adds it to trusted", func(t *testing.T) { + var callbackCallCount int + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Audience: []string{issuerConfigServer.URL}, + Issuer: issuerConfigServer.URL, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, + } + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(http.DefaultClient, logger), logger, jwt.ParserOpts{ + TrustedIssuerNotFoundFallback: func(ctx context.Context, p *jwt.Parser, iss string) (issURL string, issFound bool) { + callbackCallCount++ + addErr := p.AddTrustedIssuerURL(iss) + if addErr != nil { + return "", false + } + return iss, true + }, + }) + require.Equal(t, 0, callbackCallCount) + parsedClaims, pErr := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.NoError(t, pErr, "issuer must be considered as trusted and no error returned") + require.Equalf(t, 1, callbackCallCount, "Callback was not called by parser") + require.Equal(t, claims, parsedClaims) + parsedClaims, pErr = parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.NoError(t, pErr, "issuer must be considered as trusted and no error returned") + require.Equal(t, claims, parsedClaims) + require.Equalf(t, 1, callbackCallCount, "Callback should be called exactly once") + }) + + t.Run("jwt exp is missing", func(t *testing.T) { + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss}} + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) + require.ErrorIs(t, err, jwtgo.ErrTokenRequiredClaimMissing) + }) + + t.Run("jwt expired", func(t *testing.T) { + expiresAt := time.Now().Add(-time.Second) + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(expiresAt)}} + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) + require.ErrorIs(t, err, jwtgo.ErrTokenExpired) + }) + + t.Run("jwt not valid yet", func(t *testing.T) { + notBefore := time.Now().Add(time.Minute) + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + NotBefore: jwtgo.NewNumericDate(notBefore), + }} + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) + require.ErrorIs(t, err, jwtgo.ErrTokenNotValidYet) + }) + + t.Run("required jwt audience is missing", func(t *testing.T) { + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }} + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(http.DefaultClient, logger), logger, jwt.ParserOpts{ + RequireAudience: true, + }) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) + require.ErrorIs(t, err, jwtgo.ErrTokenRequiredClaimMissing) + var jwtErr *jwt.AudienceMissingError + require.ErrorAs(t, err, &jwtErr) + require.Equal(t, claims, jwtErr.Claims) + }) + + t.Run("jwt audience is not expected", func(t *testing.T) { + const audience = "not-expected-audience" + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{ + Audience: []string{audience}, + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }} + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(http.DefaultClient, logger), logger, jwt.ParserOpts{ + ExpectedAudience: []string{"expected-audience"}, + }) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) + require.ErrorIs(t, err, jwtgo.ErrTokenInvalidAudience) + var jwtErr *jwt.AudienceNotExpectedError + require.ErrorAs(t, err, &jwtErr) + require.Equal(t, claims, jwtErr.Claims) + }) + + t.Run("verification error", func(t *testing.T) { + jwksServer2 := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer2.Close() + + openIDCfgHandler2 := &idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer2.URL} + openIDCfgServer2 := httptest.NewServer(openIDCfgHandler2) + defer openIDCfgServer2.Close() + + const cacheUpdateMinInterval = time.Second + + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} + tokenString, err := idptest.MakeTokenString(claims, "737c5114f09b5ed05276bd4b520245982f7fb29f", idptest.GetTestRSAPrivateKey()) + require.NoError(t, err) + jwksClient := jwks.NewCachingClientWithOpts(http.DefaultClient, logger, jwks.CachingClientOpts{CacheUpdateMinInterval: cacheUpdateMinInterval}) + parser := jwt.NewParser(jwksClient, logger) + parser.AddTrustedIssuer(testIss, openIDCfgServer2.URL) + + for i := 0; i < 2; i++ { + _, err = parser.Parse(context.Background(), tokenString) + require.ErrorIs(t, err, jwtgo.ErrTokenSignatureInvalid) + require.EqualValues(t, 1, openIDCfgHandler2.ServedCount()) + require.EqualValues(t, 1, openIDCfgHandler2.ServedCount()) + } + + time.Sleep(cacheUpdateMinInterval * 2) + + _, err = parser.Parse(context.Background(), tokenString) + require.ErrorIs(t, err, jwtgo.ErrTokenSignatureInvalid) + require.EqualValues(t, 2, openIDCfgHandler2.ServedCount()) + require.EqualValues(t, 2, openIDCfgHandler2.ServedCount()) + }) +} + +func TestParser_getURLForIssuer(t *testing.T) { + tests := []struct { + Name string + IssURLPattern string + TrustedIssURLs []string + NotTrustedIssURLs []string + }{ + { + Name: "wildcard in host", + IssURLPattern: "https://*.acronis.com/bc", + TrustedIssURLs: []string{ + "https://us-cloud.acronis.com/bc", + "https://eu2-cloud.acronis.com/bc", + }, + NotTrustedIssURLs: []string{ + "http://eu2-cloud.acronis.com/bc", + "https://eu2-cloud.acronis.com", + "https://eu2-cloud.acronis.com/bc/foobar", + "https://my-site.com/eu2-cloud.acronis.com/bc", + "https://my-site.com?foo=eu2-cloud.acronis.com/bc", + "https://eu2-cloud.acronis.com/bc?foo=bar", + }, + }, + { + Name: "no wildcard in path", + IssURLPattern: "https://eu3-cloud.acronis.com/bc", + TrustedIssURLs: []string{"https://eu3-cloud.acronis.com/bc"}, + NotTrustedIssURLs: []string{ + "https://eu1-cloud.acronis.com/bc", + "https://eu2-cloud.acronis.com/bc", + }, + }, + } + for i := range tests { + tt := tests[i] + t.Run(tt.Name, func(t *testing.T) { + logger := log.NewDisabledLogger() + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + require.NoError(t, parser.AddTrustedIssuerURL(tt.IssURLPattern)) + for _, issURL := range tt.TrustedIssURLs { + u, ok := parser.GetURLForIssuer(issURL) + require.True(t, ok) + require.Equal(t, u, issURL) + } + for _, issURL := range tt.NotTrustedIssURLs { + _, ok := parser.GetURLForIssuer(issURL) + require.False(t, ok) + } + }) + } +} diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..d6e6f72 --- /dev/null +++ b/middleware.go @@ -0,0 +1,206 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "context" + "errors" + "net/http" + "strings" + + "github.com/acronis/go-appkit/httpserver/middleware" + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-appkit/restapi" + + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/jwt" +) + +// HeaderAuthorization contains the name of HTTP header with data that is used for authentication and authorization. +const HeaderAuthorization = "Authorization" + +// Authentication and authorization error codes. +// We are using "var" here because some services may want to use different error codes. +var ( + ErrCodeBearerTokenMissing = "bearerTokenMissing" + ErrCodeAuthenticationFailed = "authenticationFailed" + ErrCodeAuthorizationFailed = "authorizationFailed" +) + +// Authentication error messages. +// We are using "var" here because some services may want to use different error messages. +var ( + ErrMessageBearerTokenMissing = "Authorization bearer token is missing." + ErrMessageAuthenticationFailed = "Authentication is failed." + ErrMessageAuthorizationFailed = "Authorization is failed." +) + +type ctxKey int + +const ( + ctxKeyJWTClaims ctxKey = iota + ctxKeyBearerToken +) + +// JWTParser is an interface for parsing string representation of JWT. +type JWTParser interface { + Parse(ctx context.Context, token string) (*jwt.Claims, error) +} + +// CachingJWTParser does the same as JWTParser but stores parsed JWT claims in cache. +type CachingJWTParser interface { + JWTParser + InvalidateCache(ctx context.Context) +} + +// TokenIntrospector is an interface for introspecting tokens. +type TokenIntrospector interface { + IntrospectToken(ctx context.Context, token string) (idptoken.IntrospectionResult, error) +} + +type jwtAuthHandler struct { + next http.Handler + errorDomain string + jwtParser JWTParser + verifyAccess func(r *http.Request, claims *jwt.Claims) bool + tokenIntrospector TokenIntrospector +} + +type jwtAuthMiddlewareOpts struct { + verifyAccess func(r *http.Request, claims *jwt.Claims) bool + tokenIntrospector TokenIntrospector +} + +// JWTAuthMiddlewareOption is an option for JWTAuthMiddleware. +type JWTAuthMiddlewareOption func(options *jwtAuthMiddlewareOpts) + +// WithJWTAuthMiddlewareVerifyAccess is an option to set a function that verifies access for JWTAuthMiddleware. +func WithJWTAuthMiddlewareVerifyAccess(verifyAccess func(r *http.Request, claims *jwt.Claims) bool) JWTAuthMiddlewareOption { + return func(options *jwtAuthMiddlewareOpts) { + options.verifyAccess = verifyAccess + } +} + +// WithJWTAuthMiddlewareTokenIntrospector is an option to set a token introspector for JWTAuthMiddleware. +func WithJWTAuthMiddlewareTokenIntrospector(tokenIntrospector TokenIntrospector) JWTAuthMiddlewareOption { + return func(options *jwtAuthMiddlewareOpts) { + options.tokenIntrospector = tokenIntrospector + } +} + +// JWTAuthMiddleware is a middleware that does authentication +// by Access Token from the "Authorization" HTTP header of incoming request. +func JWTAuthMiddleware(errorDomain string, jwtParser JWTParser, opts ...JWTAuthMiddlewareOption) func(next http.Handler) http.Handler { + var options jwtAuthMiddlewareOpts + for _, opt := range opts { + opt(&options) + } + return func(next http.Handler) http.Handler { + return &jwtAuthHandler{next, errorDomain, jwtParser, options.verifyAccess, options.tokenIntrospector} + } +} + +func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + reqCtx := r.Context() + logger := middleware.GetLoggerFromContext(reqCtx) + + bearerToken := GetBearerTokenFromRequest(r) + if bearerToken == "" { + apiErr := restapi.NewError(h.errorDomain, ErrCodeBearerTokenMissing, ErrMessageBearerTokenMissing) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) + return + } + + var jwtClaims *jwt.Claims + if h.tokenIntrospector != nil { + if introspectionResult, err := h.tokenIntrospector.IntrospectToken(reqCtx, bearerToken); err != nil { + switch { + case errors.Is(err, idptoken.ErrTokenIntrospectionNotNeeded): + // Do nothing. Access Token already contains all necessary information for authN/authZ. + case errors.Is(err, idptoken.ErrTokenNotIntrospectable): + // Token is not introspectable by some reason. + // In this case, we will parse it as JWT and use it for authZ. + if logger != nil { + logger.Warn("token is not introspectable, it will be used for authentication and authorization as is", + log.Error(err)) + } + default: + apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) + return + } + } else { + if !introspectionResult.Active { + apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) + return + } + jwtClaims = &introspectionResult.Claims + } + } + + if jwtClaims == nil { + var err error + if jwtClaims, err = h.jwtParser.Parse(reqCtx, bearerToken); err != nil { + if logger != nil { + logger.Error("authentication failed", log.Error(err)) + } + apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) + return + } + } + + if h.verifyAccess != nil { + if !h.verifyAccess(r, jwtClaims) { + apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthorizationFailed, ErrMessageAuthorizationFailed) + restapi.RespondError(rw, http.StatusForbidden, apiErr, logger) + return + } + } + + reqCtx = NewContextWithBearerToken(reqCtx, bearerToken) + reqCtx = NewContextWithJWTClaims(reqCtx, jwtClaims) + h.next.ServeHTTP(rw, r.WithContext(reqCtx)) +} + +// GetBearerTokenFromRequest extracts jwt token from request headers. +func GetBearerTokenFromRequest(r *http.Request) string { + authHeader := strings.TrimSpace(r.Header.Get(HeaderAuthorization)) + if strings.HasPrefix(authHeader, "Bearer ") || strings.HasPrefix(authHeader, "bearer ") { + return authHeader[7:] + } + return "" +} + +// NewContextWithJWTClaims creates a new context with JWT claims. +func NewContextWithJWTClaims(ctx context.Context, jwtClaims *jwt.Claims) context.Context { + return context.WithValue(ctx, ctxKeyJWTClaims, jwtClaims) +} + +// GetJWTClaimsFromContext extracts JWT claims from the context. +func GetJWTClaimsFromContext(ctx context.Context) *jwt.Claims { + value := ctx.Value(ctxKeyJWTClaims) + if value == nil { + return nil + } + return value.(*jwt.Claims) +} + +// NewContextWithBearerToken creates a new context with token. +func NewContextWithBearerToken(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, ctxKeyBearerToken, token) +} + +// GetBearerTokenFromContext extracts token from the context. +func GetBearerTokenFromContext(ctx context.Context) string { + value := ctx.Value(ctxKeyBearerToken) + if value == nil { + return "" + } + return value.(string) +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..45c58c5 --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,242 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/acronis/go-appkit/testutil" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/jwt" +) + +type mockJWTAuthMiddlewareNextHandler struct { + called int + jwtClaims *jwt.Claims +} + +func (h *mockJWTAuthMiddlewareNextHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { + h.called++ + h.jwtClaims = GetJWTClaimsFromContext(r.Context()) +} + +type mockJWTParser struct { + parseCalled int + claimsToReturn *jwt.Claims + errToReturn error + passedToken string +} + +func (p *mockJWTParser) Parse(_ context.Context, token string) (*jwt.Claims, error) { + p.parseCalled++ + p.passedToken = token + return p.claimsToReturn, p.errToReturn +} + +type mockTokenIntrospector struct { + introspectCalled int + introspectedToken string + resultToReturn idptoken.IntrospectionResult + errToReturn error +} + +func (i *mockTokenIntrospector) IntrospectToken(_ context.Context, token string) (idptoken.IntrospectionResult, error) { + i.introspectCalled++ + i.introspectedToken = token + return i.resultToReturn, i.errToReturn +} + +func TestJWTAuthMiddleware(t *testing.T) { + const errDomain = "TestDomain" + + t.Run("bearer token is missing", func(t *testing.T) { + for _, headerVal := range []string{"", "foobar", "Bearer", "Bearer "} { + parser := &mockJWTParser{} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + if headerVal != "" { + req.Header.Set(HeaderAuthorization, headerVal) + } + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) + + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeBearerTokenMissing) + require.Equal(t, 0, parser.parseCalled) + require.Equal(t, 0, next.called) + require.Nil(t, next.jwtClaims) + } + }) + + t.Run("authentication failed", func(t *testing.T) { + parser := &mockJWTParser{errToReturn: errors.New("malformed JWT")} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer foobar") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) + + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 0, next.called) + require.Nil(t, next.jwtClaims) + }) + + t.Run("ok", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{claimsToReturn: &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 1, next.called) + require.NotNil(t, next.jwtClaims) + require.Equal(t, issuer, next.jwtClaims.Issuer) + }) + + t.Run("introspection failed", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{claimsToReturn: &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + introspector := &mockTokenIntrospector{errToReturn: errors.New("introspection failed")} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) + require.Equal(t, 1, introspector.introspectCalled) + require.Equal(t, 0, parser.parseCalled) + require.Equal(t, 0, next.called) + }) + + t.Run("introspection is not needed", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{claimsToReturn: &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + introspector := &mockTokenIntrospector{errToReturn: idptoken.ErrTokenIntrospectionNotNeeded} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, 1, introspector.introspectCalled) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 1, next.called) + require.NotNil(t, next.jwtClaims) + require.Equal(t, issuer, next.jwtClaims.Issuer) + }) + + t.Run("ok, token is not introspectable", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{claimsToReturn: &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + introspector := &mockTokenIntrospector{errToReturn: idptoken.ErrTokenNotIntrospectable} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, 1, introspector.introspectCalled) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 1, next.called) + require.NotNil(t, next.jwtClaims) + require.Equal(t, issuer, next.jwtClaims.Issuer) + }) + + t.Run("authentication failed, token is introspected but inactive", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{} + introspector := &mockTokenIntrospector{resultToReturn: idptoken.IntrospectionResult{Active: false}} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) + require.Equal(t, 1, introspector.introspectCalled) + require.Equal(t, 0, parser.parseCalled) + require.Equal(t, 0, next.called) + }) + + t.Run("ok, token is introspected and active", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{} + introspector := &mockTokenIntrospector{resultToReturn: idptoken.IntrospectionResult{Active: true, Claims: jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}}} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, 1, introspector.introspectCalled) + require.Equal(t, 0, parser.parseCalled) + require.Equal(t, 1, next.called) + require.NotNil(t, next.jwtClaims) + require.Equal(t, issuer, next.jwtClaims.Issuer) + }) +} + +func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) { + const errDomain = "TestDomain" + + t.Run("authorization failed", func(t *testing.T) { + parser := &mockJWTParser{claimsToReturn: &jwt.Claims{}} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + verifyAccess := NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"}) + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).ServeHTTP(resp, req) + + testutil.RequireErrorInRecorder(t, resp, http.StatusForbidden, errDomain, ErrCodeAuthorizationFailed) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 0, next.called) + require.Nil(t, next.jwtClaims) + }) + + t.Run("ok", func(t *testing.T) { + scope := []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "admin"}} + parser := &mockJWTParser{claimsToReturn: &jwt.Claims{Scope: scope}} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + verifyAccess := NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"}) + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 1, next.called) + require.NotNil(t, next.jwtClaims) + require.EqualValues(t, scope, next.jwtClaims.Scope) + }) +} diff --git a/min-coverage.txt b/min-coverage.txt new file mode 100644 index 0000000..d15a2cc --- /dev/null +++ b/min-coverage.txt @@ -0,0 +1 @@ +80