From 83ba1e791f1aff79516de7608a0a0974082434c6 Mon Sep 17 00:00:00 2001 From: Vasily Tsybenko Date: Wed, 2 Oct 2024 13:31:48 +0300 Subject: [PATCH] Initial commit --- .gitignore | 3 + .golangci.yml | 78 ++ .trufflehog3.yml | 9 + CODEOWNERS | 1 + LICENSE | 21 + README.md | 137 ++++ auth.go | 312 ++++++++ auth_test.go | 411 +++++++++++ config.go | 296 ++++++++ config_test.go | 265 +++++++ doc.go | 8 + example_test.go | 196 +++++ go.mod | 58 ++ go.sum | 221 ++++++ idptest/doc.go | 10 + idptest/grpc_server.go | 124 ++++ idptest/http_server.go | 175 +++++ idptest/jwks_handler.go | 100 +++ idptest/jwt.go | 86 +++ idptest/jwt_test.go | 54 ++ idptest/openid_configuration_handler.go | 60 ++ idptest/token_handlers.go | 103 +++ idptest/token_provider.go | 26 + idptoken/caching_introspector.go | 194 +++++ idptoken/caching_introspector_test.go | 255 +++++++ idptoken/config.go | 61 ++ idptoken/config_test.go | 105 +++ idptoken/doc.go | 10 + idptoken/grpc_client.go | 182 +++++ idptoken/idp_token.proto | 74 ++ idptoken/introspector.go | 402 ++++++++++ idptoken/introspector_test.go | 306 ++++++++ idptoken/pb/idp_token.pb.go | 682 +++++++++++++++++ idptoken/pb/idp_token_grpc.pb.go | 154 ++++ idptoken/provider.go | 692 ++++++++++++++++++ idptoken/provider_test.go | 464 ++++++++++++ internal/idputil/doc.go | 9 + internal/idputil/openid_configuration.go | 72 ++ internal/idputil/trusted_issuers_store.go | 80 ++ internal/libinfo/doc.go | 8 + internal/libinfo/lib_info.go | 26 + internal/libinfo/version.go | 33 + internal/metrics/doc.go | 8 + internal/metrics/metrics.go | 174 +++++ internal/testing/doc.go | 8 + .../testing/server_token_introspector_mock.go | 152 ++++ jwks/caching_client.go | 169 +++++ jwks/caching_client_test.go | 110 +++ jwks/client.go | 138 ++++ jwks/client_test.go | 160 ++++ jwks/doc.go | 8 + jwks/errors.go | 49 ++ jwt/caching_parser.go | 113 +++ jwt/caching_parser_test.go | 109 +++ jwt/doc.go | 8 + jwt/errors.go | 56 ++ jwt/jwt.go | 255 +++++++ jwt/jwt_test.go | 353 +++++++++ middleware.go | 206 ++++++ middleware_test.go | 242 ++++++ min-coverage.txt | 1 + 61 files changed, 8882 insertions(+) create mode 100644 .gitignore create mode 100644 .golangci.yml create mode 100644 .trufflehog3.yml create mode 100644 CODEOWNERS create mode 100644 LICENSE create mode 100644 README.md create mode 100644 auth.go create mode 100644 auth_test.go create mode 100644 config.go create mode 100644 config_test.go create mode 100644 doc.go create mode 100644 example_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 idptest/doc.go create mode 100644 idptest/grpc_server.go create mode 100644 idptest/http_server.go create mode 100644 idptest/jwks_handler.go create mode 100644 idptest/jwt.go create mode 100644 idptest/jwt_test.go create mode 100644 idptest/openid_configuration_handler.go create mode 100644 idptest/token_handlers.go create mode 100644 idptest/token_provider.go create mode 100644 idptoken/caching_introspector.go create mode 100644 idptoken/caching_introspector_test.go create mode 100644 idptoken/config.go create mode 100644 idptoken/config_test.go create mode 100644 idptoken/doc.go create mode 100644 idptoken/grpc_client.go create mode 100644 idptoken/idp_token.proto create mode 100644 idptoken/introspector.go create mode 100644 idptoken/introspector_test.go create mode 100644 idptoken/pb/idp_token.pb.go create mode 100644 idptoken/pb/idp_token_grpc.pb.go create mode 100644 idptoken/provider.go create mode 100644 idptoken/provider_test.go create mode 100644 internal/idputil/doc.go create mode 100644 internal/idputil/openid_configuration.go create mode 100644 internal/idputil/trusted_issuers_store.go create mode 100644 internal/libinfo/doc.go create mode 100644 internal/libinfo/lib_info.go create mode 100644 internal/libinfo/version.go create mode 100644 internal/metrics/doc.go create mode 100644 internal/metrics/metrics.go create mode 100644 internal/testing/doc.go create mode 100644 internal/testing/server_token_introspector_mock.go create mode 100644 jwks/caching_client.go create mode 100644 jwks/caching_client_test.go create mode 100644 jwks/client.go create mode 100644 jwks/client_test.go create mode 100644 jwks/doc.go create mode 100644 jwks/errors.go create mode 100644 jwt/caching_parser.go create mode 100644 jwt/caching_parser_test.go create mode 100644 jwt/doc.go create mode 100644 jwt/errors.go create mode 100644 jwt/jwt.go create mode 100644 jwt/jwt_test.go create mode 100644 middleware.go create mode 100644 middleware_test.go create mode 100644 min-coverage.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9ed0094 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.idea/ +.vscode/ +vendor/ \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..acf3fe4 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,78 @@ +linters-settings: + gocyclo: + min-complexity: 25 + goconst: + min-len: 2 + min-occurrences: 2 + misspell: + locale: US + lll: + line-length: 140 + goimports: + local-prefixes: github.com/acronis/go-authkit/ + gocritic: + enabled-tags: + - diagnostic + - performance + - style + - experimental + disabled-checks: + - whyNoLint + - paramTypeCombine + - sloppyReassign + settings: + hugeParam: + sizeThreshold: 256 + rangeValCopy: + sizeThreshold: 256 + funlen: + lines: 120 + statements: 60 + +linters: + disable-all: true + enable: + - bodyclose + - dogsled + - errcheck + - exportloopref + - funlen + - gochecknoinits + - goconst + - gocritic + - gocyclo + - gofmt + - goimports + - gosec + - gosimple + - govet + - ineffassign + - lll + - misspell + - nakedret + - staticcheck + - stylecheck + - typecheck + - unconvert + - unparam + - unused + - whitespace + +issues: + # Don't use default excluding to be sure all exported things (method, functions, consts and so on) have comments. + exclude-use-default: false + exclude-rules: + - path: _test\.go + linters: + - dogsled + - ineffassign + - funlen + - gocritic + - gocyclo + - gosec + - goconst + - govet + - lll + - staticcheck + - unused + - unparam diff --git a/.trufflehog3.yml b/.trufflehog3.yml new file mode 100644 index 0000000..b02d531 --- /dev/null +++ b/.trufflehog3.yml @@ -0,0 +1,9 @@ +exclude: + - message: Exclude values in test files and primitives for testing + paths: + - idptest/jwks_handler.go + - jwt/jwt_test.go + + - message: Skip false positive high-entropy sequences + id: high-entropy + pattern: cfgKeyIntrospectionGRPCTLSEnabled diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000..027c494 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @vasayxtx @MikeYast \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..807b16f --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright © 2024 Acronis International GmbH. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..3b16c7b --- /dev/null +++ b/README.md @@ -0,0 +1,137 @@ +# Simple library in Go with primitives for performing authentication and authorization + +The library includes the following packages: ++ `auth` (root directory) - provides authentication and authorization primitives for using on the server side. ++ `jwt` - provides parser for JSON Web Tokens (JWT). ++ `jwks` - provides a client for fetching and caching JSON Web Key Sets (JWKS). ++ `idptoken` - provides a client for fetching and caching Access Tokens from Identity Providers (IDP). ++ `idptest` - provides primitives for testing IDP clients. + +## Examples + +### Authenticating requests with JWT tokens + +The `JWTAuthMiddleware` function creates a middleware that authenticates requests with JWT tokens. + +It uses the `JWTParser` to parse and validate JWT. +`JWTParser` can verify JWT tokens signed with RSA (RS256, RS384, RS512) algorithms for now. +It performs /.well-known/openid-configuration request to get the JWKS URL ("jwks_uri" field) and fetches JWKS from there. +For other algorithms `jwt.SignAlgUnknownError` error will be returned. +The `JWTParser` can be created with the `NewJWTParser` function or with the `NewJWTParserWithCachingJWKS` function. +The last one is recommended for production use because it caches public keys (JWKS) that are used for verifying JWT tokens. + +See `Config` struct for more customization options. + +Example: + +```go +package main + +import ( + "net/http" + + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-authkit" +) + +func main() { + jwtConfig := auth.JWTConfig{ + TrustedIssuerURLs: []string{"https://my-idp.com"}, + //TrustedIssuers: map[string]string{"my-idp": "https://my-idp.com"}, // Use TrustedIssuers if you have a custom issuer name. + } + jwtParser, _ := auth.NewJWTParserWithCachingJWKS(&auth.Config{JWT: jwtConfig}, log.NewDisabledLogger()) + authN := auth.JWTAuthMiddleware("MyService", jwtParser) + + srvMux := http.NewServeMux() + srvMux.Handle("/", http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + _, _ = rw.Write([]byte("Hello, World!")) + })) + srvMux.Handle("/admin", authN(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + //jwtClaims := GetJWTClaimsFromContext(r.Context()) // GetJWTClaimsFromContext is a helper function to get JWT claims from context. + _, _ = rw.Write([]byte("Hello, admin!")) + }))) + + _ = http.ListenAndServe(":8080", srvMux) +} +``` + +```shell +$ curl -w "\nHTTP code: %{http_code}\n" localhost:8080 +Hello, World! +HTTP code: 200 + +$ curl -w "\nHTTP code: %{http_code}\n" localhost:8080/admin +{"error":{"domain":"MyService","code":"bearerTokenMissing","message":"Authorization bearer token is missing."}} +HTTP code: 401 +``` + +### Authorizing requests with JWT tokens + +```go +package main + +import ( + "net/http" + + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-authkit" +) + +func main() { + jwtConfig := auth.JWTConfig{TrustedIssuers: map[string]string{"my-idp": idpURL}} + jwtParser, _ := auth.NewJWTParserWithCachingJWKS(&auth.Config{JWT: jwtConfig}, log.NewDisabledLogger()) + authOnlyAdmin := auth.JWTAuthMiddlewareWithVerifyAccess("MyService", jwtParser, + auth.NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"})) + + srvMux := http.NewServeMux() + srvMux.Handle("/", http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + _, _ = rw.Write([]byte("Hello, World!")) + })) + srvMux.Handle("/admin", authOnlyAdmin(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + _, _ = rw.Write([]byte("Hello, admin!")) + }))) + + _ = http.ListenAndServe(":8080", srvMux) +} +``` + +Please see [example_test.go](./example_test.go) for a full version of the example. + +### Fetching and caching Access Tokens from Identity Providers + +The `idptoken.Provider` object is used to fetch and cache Access Tokens from Identity Providers (IDP). + +Example: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/acronis/go-authkit/idptoken" +) + +func main() { + // ... + httpClient := &http.Client{Timeout: 30 * time.Second} + source := idptoken.Source{ + URL: idpURL, + ClientID: clientID, + ClientSecret: clientSecret, + } + provider := idptoken.NewProvider(httpClient, source) + accessToken, err := provider.GetToken() + if err != nil { + log.Fatalf("failed to get access token: %v", err) + } + // ... +} +``` + +## License + +Copyright © 2024 Acronis International GmbH. + +Licensed under [MIT License](./LICENSE). diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..8737a9a --- /dev/null +++ b/auth.go @@ -0,0 +1,312 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "os" + "time" + + "github.com/acronis/go-appkit/log" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +// Default values. +const ( + DefaultHTTPClientRequestTimeout = time.Second * 30 + DefaultGRPCClientRequestTimeout = time.Second * 30 +) + +// NewJWTParser creates a new JWTParser with the given configuration. +// If cfg.JWT.ClaimsCache.Enabled is true, then jwt.CachingParser created, otherwise - jwt.Parser. +func NewJWTParser(cfg *Config, opts ...JWTParserOption) (JWTParser, error) { + var options jwtParserOptions + for _, opt := range opts { + opt(&options) + } + logger := options.logger + if logger == nil { + logger = log.NewDisabledLogger() + } + + // Make caching JWKS client. + jwksCacheUpdateMinInterval := cfg.JWKS.Cache.UpdateMinInterval + if jwksCacheUpdateMinInterval == 0 { + jwksCacheUpdateMinInterval = jwks.DefaultCacheUpdateMinInterval + } + httpClientRequestTimeout := cfg.HTTPClient.RequestTimeout + if httpClientRequestTimeout == 0 { + httpClientRequestTimeout = DefaultHTTPClientRequestTimeout + } + jwksClientOpts := jwks.CachingClientOpts{ + ClientOpts: jwks.ClientOpts{PrometheusLibInstanceLabel: options.prometheusLibInstanceLabel}, + CacheUpdateMinInterval: jwksCacheUpdateMinInterval, + } + jwksClient := jwks.NewCachingClientWithOpts(&http.Client{Timeout: httpClientRequestTimeout}, logger, jwksClientOpts) + + // Make JWT parser. + + if len(cfg.JWT.TrustedIssuers) == 0 && len(cfg.JWT.TrustedIssuerURLs) == 0 { + logger.Warn("list of trusted issuers is empty, jwt parsing may not work properly") + } + + parserOpts := jwt.ParserOpts{ + RequireAudience: cfg.JWT.RequireAudience, + ExpectedAudience: cfg.JWT.ExpectedAudience, + TrustedIssuerNotFoundFallback: options.trustedIssuerNotFoundFallback, + } + + if cfg.JWT.ClaimsCache.Enabled { + cachingJWTParser, err := jwt.NewCachingParserWithOpts(jwksClient, logger, jwt.CachingParserOpts{ + ParserOpts: parserOpts, + CacheMaxEntries: cfg.JWT.ClaimsCache.MaxEntries, + }) + if err != nil { + return nil, fmt.Errorf("new caching JWT parser: %w", err) + } + if err = addTrustedIssuers(cachingJWTParser, cfg.JWT.TrustedIssuers, cfg.JWT.TrustedIssuerURLs); err != nil { + return nil, err + } + return cachingJWTParser, nil + } + + jwtParser := jwt.NewParserWithOpts(jwksClient, logger, parserOpts) + if err := addTrustedIssuers(jwtParser, cfg.JWT.TrustedIssuers, cfg.JWT.TrustedIssuerURLs); err != nil { + return nil, err + } + return jwtParser, nil +} + +type jwtParserOptions struct { + logger log.FieldLogger + prometheusLibInstanceLabel string + trustedIssuerNotFoundFallback jwt.TrustedIssNotFoundFallback +} + +// JWTParserOption is an option for creating JWTParser. +type JWTParserOption func(options *jwtParserOptions) + +// WithJWTParserLogger sets the logger for JWTParser. +func WithJWTParserLogger(logger log.FieldLogger) JWTParserOption { + return func(options *jwtParserOptions) { + options.logger = logger + } +} + +// WithJWTParserPrometheusLibInstanceLabel sets the Prometheus lib instance label for JWTParser. +func WithJWTParserPrometheusLibInstanceLabel(label string) JWTParserOption { + return func(options *jwtParserOptions) { + options.prometheusLibInstanceLabel = label + } +} + +// WithJWTParserTrustedIssuerNotFoundFallback sets the fallback for JWTParser when trusted issuer is not found. +func WithJWTParserTrustedIssuerNotFoundFallback(fallback jwt.TrustedIssNotFoundFallback) JWTParserOption { + return func(options *jwtParserOptions) { + options.trustedIssuerNotFoundFallback = fallback + } +} + +// NewTokenIntrospector creates a new TokenIntrospector with the given configuration, token provider and scope filter. +// If cfg.Introspection.ClaimsCache.Enabled or cfg.Introspection.NegativeCache.Enabled is true, +// then idptoken.CachingIntrospector created, otherwise - idptoken.Introspector. +// Please note that the tokenProvider should be able to provide access token with the policy for introspection. +// scopeFilter is a list of filters that will be applied to the introspected token. +func NewTokenIntrospector( + cfg *Config, + tokenProvider idptoken.IntrospectionTokenProvider, + scopeFilter []idptoken.IntrospectionScopeFilterAccessPolicy, + opts ...TokenIntrospectorOption, +) (TokenIntrospector, error) { + var options tokenIntrospectorOptions + for _, opt := range opts { + opt(&options) + } + logger := options.logger + if logger == nil { + logger = log.NewDisabledLogger() + } + + if len(cfg.JWT.TrustedIssuers) == 0 && len(cfg.JWT.TrustedIssuerURLs) == 0 { + logger.Warn("list of trusted issuers is empty, jwt introspection may not work properly") + } + + var grpcClient *idptoken.GRPCClient + if cfg.Introspection.GRPC.Target != "" { + transportCreds, err := makeGRPCTransportCredentials(cfg.Introspection.GRPC.TLS) + if err != nil { + return nil, fmt.Errorf("make grpc transport credentials: %w", err) + } + grpcClient, err = idptoken.NewGRPCClientWithOpts(cfg.Introspection.GRPC.Target, transportCreds, + idptoken.GRPCClientOpts{RequestTimeout: cfg.GRPCClient.RequestTimeout, Logger: logger}) + if err != nil { + return nil, fmt.Errorf("new grpc client: %w", err) + } + } + + httpClientRequestTimeout := cfg.HTTPClient.RequestTimeout + if httpClientRequestTimeout == 0 { + httpClientRequestTimeout = DefaultHTTPClientRequestTimeout + } + + introspectorOpts := idptoken.IntrospectorOpts{ + StaticHTTPEndpoint: cfg.Introspection.Endpoint, + GRPCClient: grpcClient, + HTTPClient: &http.Client{Timeout: httpClientRequestTimeout}, + AccessTokenScope: cfg.Introspection.AccessTokenScope, + Logger: logger, + MinJWTVersion: cfg.Introspection.MinJWTVersion, + ScopeFilter: scopeFilter, + TrustedIssuerNotFoundFallback: options.trustedIssuerNotFoundFallback, + PrometheusLibInstanceLabel: options.prometheusLibInstanceLabel, + } + + if cfg.Introspection.ClaimsCache.Enabled || cfg.Introspection.NegativeCache.Enabled { + cachingIntrospector, err := idptoken.NewCachingIntrospectorWithOpts(tokenProvider, idptoken.CachingIntrospectorOpts{ + IntrospectorOpts: introspectorOpts, + ClaimsCache: idptoken.CachingIntrospectorCacheOpts{ + Enabled: cfg.Introspection.ClaimsCache.Enabled, + MaxEntries: cfg.Introspection.ClaimsCache.MaxEntries, + TTL: cfg.Introspection.ClaimsCache.TTL, + }, + NegativeCache: idptoken.CachingIntrospectorCacheOpts{ + Enabled: cfg.Introspection.NegativeCache.Enabled, + MaxEntries: cfg.Introspection.NegativeCache.MaxEntries, + TTL: cfg.Introspection.NegativeCache.TTL, + }, + }) + if err != nil { + return nil, fmt.Errorf("new caching introspector: %w", err) + } + if err = addTrustedIssuers(cachingIntrospector, cfg.JWT.TrustedIssuers, cfg.JWT.TrustedIssuerURLs); err != nil { + return nil, err + } + return cachingIntrospector, nil + } + + introspector := idptoken.NewIntrospectorWithOpts(tokenProvider, introspectorOpts) + if err := addTrustedIssuers(introspector, cfg.JWT.TrustedIssuers, cfg.JWT.TrustedIssuerURLs); err != nil { + return nil, err + } + return introspector, nil +} + +type tokenIntrospectorOptions struct { + logger log.FieldLogger + prometheusLibInstanceLabel string + trustedIssuerNotFoundFallback idptoken.TrustedIssNotFoundFallback +} + +// TokenIntrospectorOption is an option for creating TokenIntrospector. +type TokenIntrospectorOption func(options *tokenIntrospectorOptions) + +// WithTokenIntrospectorLogger sets the logger for TokenIntrospector. +func WithTokenIntrospectorLogger(logger log.FieldLogger) TokenIntrospectorOption { + return func(options *tokenIntrospectorOptions) { + options.logger = logger + } +} + +// WithTokenIntrospectorPrometheusLibInstanceLabel sets the Prometheus lib instance label for TokenIntrospector. +func WithTokenIntrospectorPrometheusLibInstanceLabel(label string) TokenIntrospectorOption { + return func(options *tokenIntrospectorOptions) { + options.prometheusLibInstanceLabel = label + } +} + +// WithTokenIntrospectorTrustedIssuerNotFoundFallback sets the fallback for TokenIntrospector +// when trusted issuer is not found. +func WithTokenIntrospectorTrustedIssuerNotFoundFallback( + fallback idptoken.TrustedIssNotFoundFallback, +) TokenIntrospectorOption { + return func(options *tokenIntrospectorOptions) { + options.trustedIssuerNotFoundFallback = fallback + } +} + +// Role is a representation of role which may be used for verifying access. +type Role struct { + Namespace string + Name string +} + +// NewVerifyAccessByRolesInJWT creates a new function which may be used for verifying access by roles in JWT scope. +func NewVerifyAccessByRolesInJWT(roles ...Role) func(r *http.Request, claims *jwt.Claims) bool { + return func(_ *http.Request, claims *jwt.Claims) bool { + for i := range roles { + for j := range claims.Scope { + if roles[i].Name == claims.Scope[j].Role && roles[i].Namespace == claims.Scope[j].ResourceNamespace { + return true + } + } + } + return false + } +} + +// NewVerifyAccessByRolesInJWTMaker creates a new function which may be used for verifying access by roles in JWT scope given a namespace. +func NewVerifyAccessByRolesInJWTMaker(namespace string) func(roleNames ...string) func(r *http.Request, claims *jwt.Claims) bool { + return func(roleNames ...string) func(r *http.Request, claims *jwt.Claims) bool { + roles := make([]Role, 0, len(roleNames)) + for i := range roleNames { + roles = append(roles, Role{Namespace: namespace, Name: roleNames[i]}) + } + return NewVerifyAccessByRolesInJWT(roles...) + } +} + +type issuerParser interface { + AddTrustedIssuer(issName string, issURL string) + AddTrustedIssuerURL(issURL string) error +} + +func addTrustedIssuers(issParser issuerParser, issuers map[string]string, issuerURLs []string) error { + for issName, issURL := range issuers { + issParser.AddTrustedIssuer(issName, issURL) + } + for _, issURL := range issuerURLs { + if err := issParser.AddTrustedIssuerURL(issURL); err != nil { + return fmt.Errorf("add trusted issuer URL: %w", err) + } + } + return nil +} + +func makeGRPCTransportCredentials(tlsCfg GRPCTLSConfig) (credentials.TransportCredentials, error) { + if !tlsCfg.Enabled { + return insecure.NewCredentials(), nil + } + + config := &tls.Config{} // nolint: gosec // TLS 1.2 is used by default. + if tlsCfg.CACert != "" { + caCert, err := os.ReadFile(tlsCfg.CACert) + if err != nil { + return nil, fmt.Errorf("read CA's certificate: %w", err) + } + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to add CA's certificate") + } + config.RootCAs = certPool + } + if tlsCfg.ClientCert != "" && tlsCfg.ClientKey != "" { + clientCert, err := tls.LoadX509KeyPair(tlsCfg.ClientCert, tlsCfg.ClientKey) + if err != nil { + return nil, fmt.Errorf("load client's certificate and key: %w", err) + } + config.Certificates = []tls.Certificate{clientCert} + } + return credentials.NewTLS(config), nil +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..f327704 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,411 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + gotesting "testing" + "time" + + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/idptoken/pb" + "github.com/acronis/go-authkit/internal/testing" + "github.com/acronis/go-authkit/jwt" +) + +func TestNewJWTParser(t *gotesting.T) { + const testIss = "test-issuer" + + idpSrv := idptest.NewHTTPServer() + require.NoError(t, idpSrv.StartAndWaitForReady(time.Second)) + defer func() { _ = idpSrv.Shutdown(context.Background()) }() + + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: idpSrv.URL(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)), + }, + Scope: []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "ro_admin"}}, + } + token := idptest.MustMakeTokenStringSignedWithTestKey(claims) + + claimsWithNamedIssuer := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)), + }, + Scope: []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "admin"}}, + } + tokenWithNamedIssuer := idptest.MustMakeTokenStringSignedWithTestKey(claimsWithNamedIssuer) + + tests := []struct { + name string + token string + cfg *Config + expectedClaims *jwt.Claims + checkFn func(t *gotesting.T, jwtParser JWTParser) + }{ + { + name: "new jwt parser, trusted issuers map", + cfg: &Config{JWT: JWTConfig{TrustedIssuers: map[string]string{testIss: idpSrv.URL()}}}, + token: tokenWithNamedIssuer, + expectedClaims: claimsWithNamedIssuer, + checkFn: func(t *gotesting.T, jwtParser JWTParser) { + require.IsType(t, &jwt.Parser{}, jwtParser) + }, + }, + { + name: "new jwt parser, trusted issuer urls", + cfg: &Config{JWT: JWTConfig{TrustedIssuerURLs: []string{idpSrv.URL()}}}, + token: token, + expectedClaims: claims, + checkFn: func(t *gotesting.T, jwtParser JWTParser) { + require.IsType(t, &jwt.Parser{}, jwtParser) + }, + }, + { + name: "new caching jwt parser, trusted issuers map", + cfg: &Config{JWT: JWTConfig{TrustedIssuers: map[string]string{testIss: idpSrv.URL()}, ClaimsCache: ClaimsCacheConfig{Enabled: true}}}, + token: tokenWithNamedIssuer, + expectedClaims: claimsWithNamedIssuer, + checkFn: func(t *gotesting.T, jwtParser JWTParser) { + require.IsType(t, &jwt.CachingParser{}, jwtParser) + cachingParser := jwtParser.(*jwt.CachingParser) + require.Equal(t, 1, cachingParser.ClaimsCache.Len()) + }, + }, + { + name: "new caching jwt parser, trusted issuer urls", + cfg: &Config{JWT: JWTConfig{TrustedIssuerURLs: []string{idpSrv.URL()}, ClaimsCache: ClaimsCacheConfig{Enabled: true}}}, + token: token, + expectedClaims: claims, + checkFn: func(t *gotesting.T, jwtParser JWTParser) { + require.IsType(t, &jwt.CachingParser{}, jwtParser) + cachingParser := jwtParser.(*jwt.CachingParser) + require.Equal(t, 1, cachingParser.ClaimsCache.Len()) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *gotesting.T) { + jwtParser, err := NewJWTParser(tt.cfg) + require.NoError(t, err) + + parsedClaims, err := jwtParser.Parse(context.Background(), tt.token) + require.NoError(t, err) + require.Equal(t, tt.expectedClaims, parsedClaims) + if tt.checkFn != nil { + tt.checkFn(t, jwtParser) + } + }) + } +} + +func TestNewTokenIntrospector(t *gotesting.T) { + const testIss = "test-issuer" + + httpServerIntrospector := testing.NewHTTPServerTokenIntrospectorMock() + grpcServerIntrospector := testing.NewGRPCServerTokenIntrospectorMock() + + // Start testing HTTP IDP server. + httpIDPSrv := idptest.NewHTTPServer(idptest.WithHTTPTokenIntrospector(httpServerIntrospector)) + require.NoError(t, httpIDPSrv.StartAndWaitForReady(time.Second)) + defer func() { _ = httpIDPSrv.Shutdown(context.Background()) }() + + // Generate a self-signed certificate for the testing gRPC IDP server and start it. + tlsCert, certPEM, _ := generateSelfSignedRSACert(t) + certFile := filepath.Join(t.TempDir(), "cert.pem") + require.NoError(t, os.WriteFile(certFile, certPEM, 0644)) + grpcIDPSrv := idptest.NewGRPCServer( + idptest.WithGRPCTokenIntrospector(grpcServerIntrospector), + idptest.WithGRPCServerOptions(grpc.Creds(credentials.NewServerTLSFromCert(&tlsCert)))) + require.NoError(t, grpcIDPSrv.StartAndWaitForReady(time.Second)) + defer func() { grpcIDPSrv.GracefulStop() }() + + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: httpIDPSrv.URL(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)), + }, + Scope: []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "ro_admin"}}, + } + token := idptest.MustMakeTokenStringSignedWithTestKey(claims) + + claimsWithNamedIssuer := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)), + }, + Scope: []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "admin"}}, + } + tokenWithNamedIssuer := idptest.MustMakeTokenStringSignedWithTestKey(claimsWithNamedIssuer) + + opaqueToken := "opaque-token-" + uuid.NewString() + opaqueTokenScope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "admin", + ResourcePath: "resource-" + uuid.NewString(), + }} + httpServerIntrospector.SetResultForToken(opaqueToken, idptoken.IntrospectionResult{ + Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueTokenScope}}) + grpcServerIntrospector.SetResultForToken(opaqueToken, &pb.IntrospectTokenResponse{ + Active: true, TokenType: idptoken.TokenTypeBearer, Scope: []*pb.AccessTokenScope{ + { + TenantUuid: opaqueTokenScope[0].TenantUUID, + ResourceNamespace: opaqueTokenScope[0].ResourceNamespace, + RoleName: opaqueTokenScope[0].Role, + ResourcePath: opaqueTokenScope[0].ResourcePath, + }, + }}) + + tests := []struct { + name string + cfg *Config + token string + expectedResult idptoken.IntrospectionResult + checkFn func(t *gotesting.T, introspector TokenIntrospector) + }{ + { + name: "new token introspector, dynamic endpoint, trusted issuers map", + cfg: &Config{JWT: JWTConfig{TrustedIssuers: map[string]string{testIss: httpIDPSrv.URL()}}, Introspection: IntrospectionConfig{Enabled: true}}, + token: tokenWithNamedIssuer, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: "bearer", + Claims: *claimsWithNamedIssuer, + }, + checkFn: func(t *gotesting.T, introspector TokenIntrospector) { + require.IsType(t, &idptoken.Introspector{}, introspector) + }, + }, + { + name: "new token introspector, dynamic endpoint, trusted issuer urls", + cfg: &Config{JWT: JWTConfig{TrustedIssuerURLs: []string{httpIDPSrv.URL()}}, Introspection: IntrospectionConfig{Enabled: true}}, + token: token, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: "bearer", + Claims: *claims, + }, + checkFn: func(t *gotesting.T, introspector TokenIntrospector) { + require.IsType(t, &idptoken.Introspector{}, introspector) + }, + }, + { + name: "new caching token introspector, dynamic endpoint, trusted issuers map", + cfg: &Config{ + JWT: JWTConfig{TrustedIssuers: map[string]string{testIss: httpIDPSrv.URL()}}, + Introspection: IntrospectionConfig{Enabled: true, ClaimsCache: IntrospectionCacheConfig{Enabled: true}}, + }, + token: tokenWithNamedIssuer, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: "bearer", + Claims: *claimsWithNamedIssuer, + }, + checkFn: func(t *gotesting.T, introspector TokenIntrospector) { + require.IsType(t, &idptoken.CachingIntrospector{}, introspector) + cachingIntrospector := introspector.(*idptoken.CachingIntrospector) + require.Equal(t, 1, cachingIntrospector.ClaimsCache.Len(context.Background())) + }, + }, + { + name: "new caching token introspector, dynamic endpoint, trusted issuer urls", + cfg: &Config{ + JWT: JWTConfig{TrustedIssuerURLs: []string{httpIDPSrv.URL()}}, + Introspection: IntrospectionConfig{Enabled: true, ClaimsCache: IntrospectionCacheConfig{Enabled: true}}, + }, + token: token, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: "bearer", + Claims: *claims, + }, + checkFn: func(t *gotesting.T, introspector TokenIntrospector) { + require.IsType(t, &idptoken.CachingIntrospector{}, introspector) + cachingIntrospector := introspector.(*idptoken.CachingIntrospector) + require.Equal(t, 1, cachingIntrospector.ClaimsCache.Len(context.Background())) + }, + }, + { + name: "new caching token introspector, static http endpoint", + cfg: &Config{ + Introspection: IntrospectionConfig{ + Enabled: true, + ClaimsCache: IntrospectionCacheConfig{Enabled: true}, + Endpoint: httpIDPSrv.URL() + idptest.TokenIntrospectionEndpointPath, + }, + }, + token: opaqueToken, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: "bearer", + Claims: jwt.Claims{Scope: opaqueTokenScope}, + }, + checkFn: func(t *gotesting.T, introspector TokenIntrospector) { + require.IsType(t, &idptoken.CachingIntrospector{}, introspector) + cachingIntrospector := introspector.(*idptoken.CachingIntrospector) + require.Equal(t, 1, cachingIntrospector.ClaimsCache.Len(context.Background())) + }, + }, + { + name: "new token introspector, gRPC target, tls enabled", + cfg: &Config{ + JWT: JWTConfig{TrustedIssuerURLs: []string{httpIDPSrv.URL()}}, + Introspection: IntrospectionConfig{ + Enabled: true, + GRPC: IntrospectionGRPCConfig{ + Target: grpcIDPSrv.Addr(), + TLS: GRPCTLSConfig{ + Enabled: true, + CACert: certFile, + }, + }, + Endpoint: httpIDPSrv.URL() + idptest.TokenIntrospectionEndpointPath, + }, + }, + token: opaqueToken, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: "bearer", + Claims: jwt.Claims{Scope: opaqueTokenScope}, + }, + checkFn: func(t *gotesting.T, introspector TokenIntrospector) { + require.IsType(t, &idptoken.Introspector{}, introspector) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *gotesting.T) { + jwtParser, err := NewJWTParser(tt.cfg) + require.NoError(t, err) + httpServerIntrospector.JWTParser = jwtParser + grpcServerIntrospector.JWTParser = jwtParser + + introspector, err := NewTokenIntrospector(tt.cfg, idptest.NewSimpleTokenProvider("access-token"), nil) + require.NoError(t, err) + + result, err := introspector.IntrospectToken(context.Background(), tt.token) + require.NoError(t, err) + require.Equal(t, tt.expectedResult, result) + if tt.checkFn != nil { + tt.checkFn(t, introspector) + } + }) + } +} + +func TestNewVerifyAccessByJWTRoles(t *gotesting.T) { + jwtClaims := &jwt.Claims{Scope: []jwt.AccessPolicy{ + {ResourceNamespace: "policy_manager", Role: "admin"}, + {ResourceNamespace: "scan_service", Role: "admin"}, + {Role: "backup_user"}, + {ResourceNamespace: "agent_manager", Role: "agent_viewer"}, + }} + cases := []struct { + roles []Role + want bool + }{ + {[]Role{{Name: "tenant_viewer"}}, false}, + {[]Role{{Name: "backup_user"}}, true}, + {[]Role{{Namespace: "alert_manager", Name: "admin"}}, false}, + {[]Role{{Namespace: "policy_manager", Name: "admin"}}, true}, + {[]Role{{Namespace: "alert_manager", Name: "admin"}, {Name: "tenant_viewer"}}, false}, + {[]Role{{Namespace: "alert_manager", Name: "admin"}, {Name: "tenant_viewer"}, {Namespace: "policy_manager", Name: "admin"}}, true}, + } + for _, c := range cases { + got := NewVerifyAccessByRolesInJWT(c.roles...)(httptest.NewRequest(http.MethodGet, "/", nil), jwtClaims) + require.Equal(t, c.want, got, "want %v, got %v, roles %+v", c.want, got, c.roles) + } +} + +func TestNewVerifyAccessByJWTRolesMaker(t *gotesting.T) { + jwtClaims := &jwt.Claims{Scope: []jwt.AccessPolicy{ + {ResourceNamespace: "policy_manager", Role: "admin"}, + {ResourceNamespace: "scan_service", Role: "admin"}, + {Role: "backup_user"}, + {ResourceNamespace: "agent_manager", Role: "agent_viewer"}, + {ResourceNamespace: "agent_manager", Role: "agent_registrar"}, + }} + cases := []struct { + roleNamespace string + roleNames []string + want bool + }{ + {"agent_manager", []string{"agent_viewer", "agent_registrar"}, true}, + {"policy_manager", []string{"admin"}, true}, + {"", []string{"backup_user"}, true}, + {"alert_manager", []string{"admin"}, false}, + } + + for _, c := range cases { + got := NewVerifyAccessByRolesInJWTMaker(c.roleNamespace)(c.roleNames...)(httptest.NewRequest(http.MethodGet, "/", nil), jwtClaims) + require.Equal(t, c.want, got, "want %v, got %v, roleNamespace: %v, roleNames %+v", c.want, got, c.roleNamespace, c.roleNames) + } +} + +func generateSelfSignedRSACert(t *gotesting.T) (tlsCert tls.Certificate, certPEM []byte, keyPEM []byte) { + t.Helper() + + // Create a private key + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + // Create a certificate template + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) // 1 year + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + require.NoError(t, err) + + certTemplate := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"My Organization"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + // Create the certificate + certDER, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv) + require.NoError(t, err) + + // PEM encode the certificate and private key + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + // Load the certificate and key as tls.Certificate + tlsCert, err = tls.X509KeyPair(certPEM, keyPEM) + require.NoError(t, err) + + return tlsCert, certPEM, keyPEM +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..64c4366 --- /dev/null +++ b/config.go @@ -0,0 +1,296 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "fmt" + "net/url" + "time" + + "github.com/acronis/go-appkit/config" + + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +const ( + cfgKeyHTTPClientRequestTimeout = "auth.httpClient.requestTimeout" + cfgKeyGRPCClientRequestTimeout = "auth.grpcClient.requestTimeout" + cfgKeyJWTTrustedIssuers = "auth.jwt.trustedIssuers" + cfgKeyJWTTrustedIssuerURLs = "auth.jwt.trustedIssuerUrls" + cfgKeyJWTRequireAudience = "auth.jwt.requireAudience" + cfgKeyJWTExceptedAudience = "auth.jwt.expectedAudience" + cfgKeyJWTClaimsCacheEnabled = "auth.jwt.claimsCache.enabled" + cfgKeyJWTClaimsCacheMaxEntries = "auth.jwt.claimsCache.maxEntries" + cfgKeyJWKSCacheUpdateMinInterval = "auth.jwks.cache.updateMinInterval" + cfgKeyIntrospectionEnabled = "auth.introspection.enabled" + cfgKeyIntrospectionEndpoint = "auth.introspection.endpoint" + cfgKeyIntrospectionGRPCTarget = "auth.introspection.grpc.target" + cfgKeyIntrospectionGRPCTLSEnabled = "auth.introspection.grpc.tls.enabled" + cfgKeyIntrospectionGRPCTLSCACert = "auth.introspection.grpc.tls.caCert" + cfgKeyIntrospectionGRPCTLSClientCert = "auth.introspection.grpc.tls.clientCert" + cfgKeyIntrospectionGRPCTLSClientKey = "auth.introspection.grpc.tls.clientKey" + cfgKeyIntrospectionAccessTokenScope = "auth.introspection.accessTokenScope" // nolint:gosec // false positive + cfgKeyIntrospectionMinJWTVer = "auth.introspection.minJWTVersion" + cfgKeyIntrospectionClaimsCacheEnabled = "auth.introspection.claimsCache.enabled" + cfgKeyIntrospectionClaimsCacheMaxEntries = "auth.introspection.claimsCache.maxEntries" + cfgKeyIntrospectionClaimsCacheTTL = "auth.introspection.claimsCache.ttl" + cfgKeyIntrospectionNegativeCacheEnabled = "auth.introspection.negativeCache.enabled" + cfgKeyIntrospectionNegativeCacheMaxEntries = "auth.introspection.negativeCache.maxEntries" + cfgKeyIntrospectionNegativeCacheTTL = "auth.introspection.negativeCache.ttl" +) + +// JWTConfig is configuration of how JWT will be verified. +type JWTConfig struct { + TrustedIssuers map[string]string + TrustedIssuerURLs []string + RequireAudience bool + ExpectedAudience []string + ClaimsCache ClaimsCacheConfig +} + +// JWKSConfig is configuration of how JWKS will be used. +type JWKSConfig struct { + Cache struct { + UpdateMinInterval time.Duration + } +} + +// IntrospectionConfig is a configuration of how token introspection will be used. +type IntrospectionConfig struct { + Enabled bool + + Endpoint string + AccessTokenScope []string + + // MinJWTVersion is a minimum version of JWT that will be accepted for introspection. + // NOTE: it's a temporary solution for determining whether introspection is needed or not, + // and it will be removed in the future. + MinJWTVersion int + + ClaimsCache IntrospectionCacheConfig + NegativeCache IntrospectionCacheConfig + + GRPC IntrospectionGRPCConfig +} + +// ClaimsCacheConfig is a configuration of how claims cache will be used. +type ClaimsCacheConfig struct { + Enabled bool + MaxEntries int +} + +// IntrospectionCacheConfig is a configuration of how claims cache will be used for introspection. +type IntrospectionCacheConfig struct { + Enabled bool + MaxEntries int + TTL time.Duration +} + +// IntrospectionGRPCConfig is a configuration of how token will be introspected via gRPC. +type IntrospectionGRPCConfig struct { + Target string + RequestTimeout time.Duration + TLS GRPCTLSConfig +} + +// GRPCTLSConfig is a configuration of how gRPC connection will be secured. +type GRPCTLSConfig struct { + Enabled bool + CACert string + ClientCert string + ClientKey string +} + +type HTTPClientConfig struct { + RequestTimeout time.Duration +} + +type GRPCClientConfig struct { + RequestTimeout time.Duration +} + +// Config represents a set of configuration parameters for authentication and authorization. +type Config struct { + HTTPClient HTTPClientConfig + GRPCClient GRPCClientConfig + + JWT JWTConfig + JWKS JWKSConfig + Introspection IntrospectionConfig + + keyPrefix string +} + +var _ config.Config = (*Config)(nil) +var _ config.KeyPrefixProvider = (*Config)(nil) + +// NewConfig creates a new instance of the Config. +func NewConfig() *Config { + return NewConfigWithKeyPrefix("") +} + +// NewConfigWithKeyPrefix creates a new instance of the Config. +// Allows specifying key prefix which will be used for parsing configuration parameters. +func NewConfigWithKeyPrefix(keyPrefix string) *Config { + return &Config{keyPrefix: keyPrefix} +} + +// KeyPrefix returns a key prefix with which all configuration parameters should be presented. +func (c *Config) KeyPrefix() string { + return c.keyPrefix +} + +// SetProviderDefaults sets default configuration values for auth in config.DataProvider. +func (c *Config) SetProviderDefaults(dp config.DataProvider) { + dp.SetDefault(cfgKeyHTTPClientRequestTimeout, DefaultHTTPClientRequestTimeout.String()) + dp.SetDefault(cfgKeyGRPCClientRequestTimeout, DefaultGRPCClientRequestTimeout.String()) + dp.SetDefault(cfgKeyJWTClaimsCacheMaxEntries, jwt.DefaultClaimsCacheMaxEntries) + dp.SetDefault(cfgKeyJWKSCacheUpdateMinInterval, jwks.DefaultCacheUpdateMinInterval.String()) + dp.SetDefault(cfgKeyIntrospectionMinJWTVer, idptoken.MinJWTVersionForIntrospection) + dp.SetDefault(cfgKeyIntrospectionClaimsCacheMaxEntries, idptoken.DefaultIntrospectionClaimsCacheMaxEntries) + dp.SetDefault(cfgKeyIntrospectionClaimsCacheTTL, idptoken.DefaultIntrospectionClaimsCacheTTL.String()) + dp.SetDefault(cfgKeyIntrospectionNegativeCacheMaxEntries, idptoken.DefaultIntrospectionNegativeCacheMaxEntries) + dp.SetDefault(cfgKeyIntrospectionNegativeCacheTTL, idptoken.DefaultIntrospectionNegativeCacheTTL.String()) +} + +// Set sets auth configuration values from config.DataProvider. +func (c *Config) Set(dp config.DataProvider) error { + var err error + + if c.HTTPClient.RequestTimeout, err = dp.GetDuration(cfgKeyHTTPClientRequestTimeout); err != nil { + return err + } + if c.GRPCClient.RequestTimeout, err = dp.GetDuration(cfgKeyGRPCClientRequestTimeout); err != nil { + return err + } + if err = c.setJWTConfig(dp); err != nil { + return err + } + if err = c.setJWKSConfig(dp); err != nil { + return err + } + if err = c.setIntrospectionConfig(dp); err != nil { + return err + } + + return nil +} + +func (c *Config) setJWTConfig(dp config.DataProvider) error { + var err error + + if c.JWT.TrustedIssuers, err = dp.GetStringMapString(cfgKeyJWTTrustedIssuers); err != nil { + return err + } + if c.JWT.TrustedIssuerURLs, err = dp.GetStringSlice(cfgKeyJWTTrustedIssuerURLs); err != nil { + return err + } + for _, issURL := range c.JWT.TrustedIssuerURLs { + if _, err = url.Parse(issURL); err != nil { + return dp.WrapKeyErr(cfgKeyJWTTrustedIssuerURLs, err) + } + } + if c.JWT.RequireAudience, err = dp.GetBool(cfgKeyJWTRequireAudience); err != nil { + return err + } + if c.JWT.ExpectedAudience, err = dp.GetStringSlice(cfgKeyJWTExceptedAudience); err != nil { + return err + } + if c.JWT.ClaimsCache.Enabled, err = dp.GetBool(cfgKeyJWTClaimsCacheEnabled); err != nil { + return err + } + if c.JWT.ClaimsCache.MaxEntries, err = dp.GetInt(cfgKeyJWTClaimsCacheMaxEntries); err != nil { + return err + } + if c.JWT.ClaimsCache.MaxEntries < 0 { + return dp.WrapKeyErr(cfgKeyJWTClaimsCacheMaxEntries, fmt.Errorf("max entries should be non-negative")) + } + + return nil +} + +func (c *Config) setJWKSConfig(dp config.DataProvider) error { + var err error + if c.JWKS.Cache.UpdateMinInterval, err = dp.GetDuration(cfgKeyJWKSCacheUpdateMinInterval); err != nil { + return err + } + return nil +} + +func (c *Config) setIntrospectionConfig(dp config.DataProvider) error { + var err error + + if c.Introspection.Enabled, err = dp.GetBool(cfgKeyIntrospectionEnabled); err != nil { + return err + } + if c.Introspection.Endpoint, err = dp.GetString(cfgKeyIntrospectionEndpoint); err != nil { + return err + } + if _, err = url.Parse(c.Introspection.Endpoint); err != nil { + return dp.WrapKeyErr(cfgKeyIntrospectionEndpoint, err) + } + + // GRPC + if c.Introspection.GRPC.Target, err = dp.GetString(cfgKeyIntrospectionGRPCTarget); err != nil { + return err + } + if c.Introspection.GRPC.TLS.Enabled, err = dp.GetBool(cfgKeyIntrospectionGRPCTLSEnabled); err != nil { + return err + } + if c.Introspection.GRPC.TLS.CACert, err = dp.GetString(cfgKeyIntrospectionGRPCTLSCACert); err != nil { + return err + } + if c.Introspection.GRPC.TLS.ClientCert, err = dp.GetString(cfgKeyIntrospectionGRPCTLSClientCert); err != nil { + return err + } + if c.Introspection.GRPC.TLS.ClientKey, err = dp.GetString(cfgKeyIntrospectionGRPCTLSClientKey); err != nil { + return err + } + + if c.Introspection.AccessTokenScope, err = dp.GetStringSlice(cfgKeyIntrospectionAccessTokenScope); err != nil { + return err + } + + if c.Introspection.MinJWTVersion, err = dp.GetInt(cfgKeyIntrospectionMinJWTVer); err != nil { + return err + } + if c.Introspection.MinJWTVersion < 0 { + return dp.WrapKeyErr(cfgKeyIntrospectionMinJWTVer, fmt.Errorf("minimum JWT version should be non-negative")) + } + + // Claims cache + if c.Introspection.ClaimsCache.Enabled, err = dp.GetBool(cfgKeyIntrospectionClaimsCacheEnabled); err != nil { + return err + } + if c.Introspection.ClaimsCache.MaxEntries, err = dp.GetInt(cfgKeyIntrospectionClaimsCacheMaxEntries); err != nil { + return err + } + if c.Introspection.ClaimsCache.MaxEntries < 0 { + return dp.WrapKeyErr(cfgKeyIntrospectionClaimsCacheMaxEntries, fmt.Errorf("max entries should be non-negative")) + } + if c.Introspection.ClaimsCache.TTL, err = dp.GetDuration(cfgKeyIntrospectionClaimsCacheTTL); err != nil { + return err + } + + // Negative cache + if c.Introspection.NegativeCache.Enabled, err = dp.GetBool(cfgKeyIntrospectionNegativeCacheEnabled); err != nil { + return err + } + if c.Introspection.NegativeCache.MaxEntries, err = dp.GetInt(cfgKeyIntrospectionNegativeCacheMaxEntries); err != nil { + return err + } + if c.Introspection.NegativeCache.MaxEntries < 0 { + return dp.WrapKeyErr(cfgKeyIntrospectionNegativeCacheMaxEntries, fmt.Errorf("max entries should be non-negative")) + } + if c.Introspection.NegativeCache.TTL, err = dp.GetDuration(cfgKeyIntrospectionNegativeCacheTTL); err != nil { + return err + } + + return nil +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..81809cd --- /dev/null +++ b/config_test.go @@ -0,0 +1,265 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "bytes" + "strings" + "testing" + "time" + + "github.com/acronis/go-appkit/config" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/jwt" +) + +func TestConfig_Set(t *testing.T) { + t.Run("ok", func(t *testing.T) { + cfgData := bytes.NewBufferString(` +auth: + httpClient: + requestTimeout: 1m + grpcClient: + requestTimeout: 2m + jwt: + trustedIssuers: + my-issuer1: https://my-issuer1.com/idp + my-issuer2: https://my-issuer2.com/idp + trustedIssuerUrls: + - https://*.my-company1.com/idp + - https://*.my-company2.com/idp + requireAudience: true + expectedAudience: + - https://*.my-company1.com + - https://*.my-company2.com + jwks: + httpclient: + timeout: 1m + cache: + updateMinInterval: 5m + introspection: + enabled: true + endpoint: https://my-idp.com/introspect + claimsCache: + enabled: true + maxEntries: 42000 + ttl: 42s + negativeCache: + enabled: true + maxEntries: 777 + ttl: 77s + accessTokenScope: + - token_introspector + minJWTVersion: 3 + grpc: + target: "127.0.0.1:1234" + tls: + enabled: true + caCert: ca-cert.pem + clientCert: client-cert.pem + clientKey: client-key.pem +`) + cfg := Config{} + err := config.NewDefaultLoader("").LoadFromReader(cfgData, config.DataTypeYAML, &cfg) + require.NoError(t, err) + require.Equal(t, time.Minute*1, cfg.HTTPClient.RequestTimeout) + require.Equal(t, time.Minute*2, cfg.GRPCClient.RequestTimeout) + require.Equal(t, cfg.JWT, JWTConfig{ + TrustedIssuers: map[string]string{ + "my-issuer1": "https://my-issuer1.com/idp", + "my-issuer2": "https://my-issuer2.com/idp", + }, + TrustedIssuerURLs: []string{ + "https://*.my-company1.com/idp", + "https://*.my-company2.com/idp", + }, + RequireAudience: true, + ExpectedAudience: []string{ + "https://*.my-company1.com", + "https://*.my-company2.com", + }, + ClaimsCache: ClaimsCacheConfig{ + MaxEntries: jwt.DefaultClaimsCacheMaxEntries, + }, + }) + require.Equal(t, time.Minute*5, cfg.JWKS.Cache.UpdateMinInterval) + require.Equal(t, cfg.Introspection, IntrospectionConfig{ + Enabled: true, + Endpoint: "https://my-idp.com/introspect", + ClaimsCache: IntrospectionCacheConfig{ + Enabled: true, + MaxEntries: 42000, + TTL: time.Second * 42, + }, + NegativeCache: IntrospectionCacheConfig{ + Enabled: true, + MaxEntries: 777, + TTL: time.Second * 77, + }, + AccessTokenScope: []string{"token_introspector"}, + MinJWTVersion: 3, + GRPC: IntrospectionGRPCConfig{ + Target: "127.0.0.1:1234", + TLS: GRPCTLSConfig{ + Enabled: true, + CACert: "ca-cert.pem", + ClientCert: "client-cert.pem", + ClientKey: "client-key.pem", + }, + }, + }) + }) +} + +func TestConfig_SetErrors(t *testing.T) { + tests := []struct { + name string + cfgData string + errKey string + errMsg string + }{ + { + name: "invalid trusted issuer URL", + cfgData: ` +auth: + jwt: + trustedIssuerURLs: + - ://invalid-url +`, + errKey: cfgKeyJWTTrustedIssuerURLs, + errMsg: "missing protocol scheme", + }, + { + name: "negative claims cache max entries", + cfgData: ` +auth: + jwt: + claimsCache: + maxEntries: -1 +`, + errKey: cfgKeyJWTClaimsCacheMaxEntries, + errMsg: "max entries should be non-negative", + }, + { + name: "invalid HTTP client timeout", + cfgData: ` +auth: + httpClient: + requestTimeout: invalid +`, + errKey: cfgKeyHTTPClientRequestTimeout, + errMsg: "invalid duration", + }, + { + name: "invalid gRPC client timeout", + cfgData: ` +auth: + grpcClient: + requestTimeout: invalid +`, + errKey: cfgKeyGRPCClientRequestTimeout, + errMsg: "invalid duration", + }, + { + name: "invalid cache update min interval", + cfgData: ` +auth: + jwks: + cache: + updateMinInterval: invalid +`, + errKey: cfgKeyJWKSCacheUpdateMinInterval, + errMsg: "invalid duration", + }, + { + name: "invalid introspection endpoint URL", + cfgData: ` +auth: + introspection: + endpoint: ://invalid-url +`, + errKey: cfgKeyIntrospectionEndpoint, + errMsg: "missing protocol scheme", + }, + { + name: "negative introspection claims cache max entries", + cfgData: ` +auth: + introspection: + claimsCache: + maxEntries: -1 +`, + errKey: cfgKeyIntrospectionClaimsCacheMaxEntries, + errMsg: "max entries should be non-negative", + }, + { + name: "negative introspection negative cache max entries", + cfgData: ` +auth: + introspection: + negativeCache: + maxEntries: -1 +`, + errKey: cfgKeyIntrospectionNegativeCacheMaxEntries, + errMsg: "max entries should be non-negative", + }, + { + name: "invalid introspection claims cache TTL", + cfgData: ` +auth: + introspection: + claimsCache: + ttl: invalid +`, + errKey: cfgKeyIntrospectionClaimsCacheTTL, + errMsg: "invalid duration", + }, + { + name: "invalid introspection negative cache TTL", + cfgData: ` +auth: + introspection: + negativeCache: + ttl: invalid +`, + errKey: cfgKeyIntrospectionNegativeCacheTTL, + errMsg: "invalid duration", + }, + { + name: "invalid introspection access token scope", + cfgData: ` +auth: + introspection: + accessTokenScope: {} +`, + errKey: cfgKeyIntrospectionAccessTokenScope, + errMsg: " unable to cast", + }, + { + name: "negative introspection min JWT version", + cfgData: ` +auth: + introspection: + minJWTVersion: -1 +`, + errKey: cfgKeyIntrospectionMinJWTVer, + errMsg: "minimum JWT version should be non-negative", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgData := bytes.NewBufferString(tt.cfgData) + cfg := Config{} + err := config.NewDefaultLoader("").LoadFromReader(cfgData, config.DataTypeYAML, &cfg) + require.ErrorContains(t, err, tt.errMsg) + require.Truef(t, strings.HasPrefix(err.Error(), tt.errKey), + "expected error starts with %q, got %q", tt.errKey, err.Error()) + }) + } +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..fc4062e --- /dev/null +++ b/doc.go @@ -0,0 +1,8 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package auth provides high-level helpers and basic objects for authN/authZ. +package auth diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..f4e7405 --- /dev/null +++ b/example_test.go @@ -0,0 +1,196 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "time" + + jwtgo "github.com/golang-jwt/jwt/v5" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/jwt" +) + +func ExampleJWTAuthMiddleware() { + jwtConfig := JWTConfig{ + TrustedIssuerURLs: []string{"https://my-idp.com"}, + //TrustedIssuers: map[string]string{"my-idp": "https://my-idp.com"}, // Use TrustedIssuers if you have a custom issuer name. + } + jwtParser, _ := NewJWTParser(&Config{JWT: jwtConfig}) + authN := JWTAuthMiddleware("MyService", jwtParser) + + srvMux := http.NewServeMux() + srvMux.Handle("/", http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + _, _ = rw.Write([]byte("Hello, World!")) + })) + srvMux.Handle("/admin", authN(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + //jwtClaims := GetJWTClaimsFromContext(r.Context()) // GetJWTClaimsFromContext is a helper function to get JWT claims from context. + _, _ = rw.Write([]byte("Hello, admin!")) + }))) + + done := make(chan struct{}) + server := &http.Server{Addr: ":8080", Handler: srvMux} + go func() { + defer close(done) + _ = server.ListenAndServe() + }() + + time.Sleep(time.Second) // Wait for the server to start. + + client := &http.Client{Timeout: time.Second * 30} + + fmt.Println("GET http://localhost:8080/") + resp, _ := client.Get("http://localhost:8080/") + fmt.Println("Status code:", resp.StatusCode) + respBody, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + fmt.Println("------") + fmt.Println("GET http://localhost:8080/admin without token") + resp, _ = client.Get("http://localhost:8080/admin") + fmt.Println("Status code:", resp.StatusCode) + respBody, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + fmt.Println("------") + fmt.Println("GET http://localhost:8080/admin with invalid token") + req, _ := http.NewRequest(http.MethodGet, "http://localhost:8080/admin", http.NoBody) + req.Header["Authorization"] = []string{"Bearer invalid-token"} + resp, _ = client.Do(req) + fmt.Println("Status code:", resp.StatusCode) + respBody, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + _ = server.Shutdown(context.Background()) + <-done + + // Output: + // GET http://localhost:8080/ + // Status code: 200 + // Body: Hello, World! + // ------ + // GET http://localhost:8080/admin without token + // Status code: 401 + // Body: {"error":{"domain":"MyService","code":"bearerTokenMissing","message":"Authorization bearer token is missing."}} + // ------ + // GET http://localhost:8080/admin with invalid token + // Status code: 401 + // Body: {"error":{"domain":"MyService","code":"authenticationFailed","message":"Authentication is failed."}} +} + +func ExampleJWTAuthMiddlewareWithVerifyAccess() { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + roUserClaims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: "my-idp", + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(2 * time.Hour)), + }, + Scope: []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "read-only-user"}}, + } + roUserToken := idptest.MustMakeTokenStringSignedWithTestKey(roUserClaims) + + adminClaims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: "my-idp", + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(2 * time.Hour)), + }, + Scope: []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "admin"}}, + } + adminToken := idptest.MustMakeTokenStringSignedWithTestKey(adminClaims) + + jwtConfig := JWTConfig{TrustedIssuers: map[string]string{"my-idp": issuerConfigServer.URL}} + jwtParser, _ := NewJWTParser(&Config{JWT: jwtConfig}) + authOnlyAdmin := JWTAuthMiddleware("MyService", jwtParser, + WithJWTAuthMiddlewareVerifyAccess(NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"}))) + + srvMux := http.NewServeMux() + srvMux.Handle("/", http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + _, _ = rw.Write([]byte("Hello, World!")) + })) + srvMux.Handle("/admin", authOnlyAdmin(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + _, _ = rw.Write([]byte("Hello, admin!")) + }))) + + done := make(chan struct{}) + server := &http.Server{Addr: ":8080", Handler: srvMux} + go func() { + defer close(done) + _ = server.ListenAndServe() + }() + + time.Sleep(time.Second) // Wait for the server to start. + + client := &http.Client{Timeout: time.Second * 30} + + fmt.Println("GET http://localhost:8080/") + resp, _ := client.Get("http://localhost:8080/") + fmt.Println("Status code:", resp.StatusCode) + respBody, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + fmt.Println("------") + fmt.Println("GET http://localhost:8080/admin without token") + resp, _ = client.Get("http://localhost:8080/admin") + fmt.Println("Status code:", resp.StatusCode) + respBody, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + fmt.Println("------") + fmt.Println("GET http://localhost:8080/admin with token of read-only user") + req, _ := http.NewRequest(http.MethodGet, "http://localhost:8080/admin", http.NoBody) + req.Header["Authorization"] = []string{"Bearer " + roUserToken} + resp, _ = client.Do(req) + fmt.Println("Status code:", resp.StatusCode) + respBody, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + fmt.Println("------") + fmt.Println("GET http://localhost:8080/admin with token of admin user") + req, _ = http.NewRequest(http.MethodGet, "http://localhost:8080/admin", http.NoBody) + req.Header["Authorization"] = []string{"Bearer " + adminToken} + resp, _ = client.Do(req) + fmt.Println("Status code:", resp.StatusCode) + respBody, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + fmt.Println("Body:", string(respBody)) + + _ = server.Shutdown(context.Background()) + <-done + + // Output: + // GET http://localhost:8080/ + // Status code: 200 + // Body: Hello, World! + // ------ + // GET http://localhost:8080/admin without token + // Status code: 401 + // Body: {"error":{"domain":"MyService","code":"bearerTokenMissing","message":"Authorization bearer token is missing."}} + // ------ + // GET http://localhost:8080/admin with token of read-only user + // Status code: 403 + // Body: {"error":{"domain":"MyService","code":"authorizationFailed","message":"Authorization is failed."}} + // ------ + // GET http://localhost:8080/admin with token of admin user + // Status code: 200 + // Body: Hello, admin! +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d710935 --- /dev/null +++ b/go.mod @@ -0,0 +1,58 @@ +module github.com/acronis/go-authkit + +go 1.20 + +require ( + github.com/acronis/go-appkit v1.3.0 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/google/uuid v1.6.0 + github.com/mendsley/gojwk v0.0.0-20141217222730-4d5ec6e58103 + github.com/prometheus/client_golang v1.20.4 + github.com/stretchr/testify v1.9.0 + github.com/vasayxtx/go-glob v1.2.0 + golang.org/x/sync v0.8.0 + google.golang.org/grpc v1.64.1 + google.golang.org/protobuf v1.34.2 +) + +require ( + code.cloudfoundry.org/bytefmt v0.0.0-20240808182453-a379845013d9 // indirect + github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/hashicorp/golang-lru v1.0.2 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.55.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect + github.com/rs/xid v1.5.0 // indirect + github.com/sagikazarmark/locafero v0.6.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.7.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/spf13/viper v1.19.0 // indirect + github.com/ssgreg/logf v1.4.2 // indirect + github.com/ssgreg/logftext v1.1.1 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + github.com/throttled/throttled/v2 v2.12.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect + golang.org/x/net v0.28.0 // indirect + golang.org/x/sys v0.24.0 // indirect + golang.org/x/text v0.17.0 // indirect + golang.org/x/time v0.6.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8a2a365 --- /dev/null +++ b/go.sum @@ -0,0 +1,221 @@ +code.cloudfoundry.org/bytefmt v0.0.0-20240808182453-a379845013d9 h1:8KlrGCtoaWaaxVxi9KzED38kNIWa1qafh9bNSVZ6otk= +code.cloudfoundry.org/bytefmt v0.0.0-20240808182453-a379845013d9/go.mod h1:eF2ZbltNI7Pv+8Cuyeksu9up5FN5konuH0trDJBuscw= +github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b h1:5/++qT1/z812ZqBvqQt6ToRswSuPZ/B33m6xVHRzADU= +github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b/go.mod h1:4+EPqMRApwwE/6yo6CxiHoSnBzjRr3jsqer7frxP8y4= +github.com/acronis/go-appkit v1.3.0 h1:IaX0DbD7HWp8ykqnK9F+c8757AmP4uBHVBe8J0Wv2sw= +github.com/acronis/go-appkit v1.3.0/go.mod h1:ouqWNe1/69fwjhx+2vV81Y6iqstfDzhmC6HpZ0E/gp4= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w= +github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-redis/redis v6.15.8+incompatible h1:BKZuG6mCnRj5AOaWJXoCgf6rqTYnYJLe4en2hxT7r9o= +github.com/go-redis/redis v6.15.8+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/go-redis/redis/v8 v8.4.2/go.mod h1:A1tbYoHSa1fXwN+//ljcCYYJeLmVrwL9hbQN45Jdy0M= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= +github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= +github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mendsley/gojwk v0.0.0-20141217222730-4d5ec6e58103 h1:Z/i1e+gTZrmcGeZyWckaLfucYG6KYOXLWo4co8pZYNY= +github.com/mendsley/gojwk v0.0.0-20141217222730-4d5ec6e58103/go.mod h1:o9YPB5aGP8ob35Vy6+vyq3P3bWe7NQWzf+JLiXCiMaE= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M= +github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo/v2 v2.20.0 h1:PE84V2mHqoT1sglvHc8ZdQtPcwmvvt29WLEEO3xmdZw= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= +github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= +github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/redis/go-redis/v9 v9.0.5/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/sagikazarmark/locafero v0.6.0 h1:ON7AQg37yzcRPU69mt7gwhFEBwxI6P9T4Qu3N51bwOk= +github.com/sagikazarmark/locafero v0.6.0/go.mod h1:77OmuIc6VTraTXKXIs/uvUxKGUXjE1GbemJYHqdNjX0= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= +github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= +github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg= +github.com/ssgreg/logf v1.3.1/go.mod h1:s7bKemHNzeAi8OePMgR93dqfL4Swro4W3B2jSIyypl4= +github.com/ssgreg/logf v1.4.2 h1:J5qO5lVhFuHboQjYyTNt+0HlQifAYaLHgZLBxpDNIQQ= +github.com/ssgreg/logf v1.4.2/go.mod h1:s7bKemHNzeAi8OePMgR93dqfL4Swro4W3B2jSIyypl4= +github.com/ssgreg/logftext v1.1.1 h1:vq03mtTnUhmnznKwMeoW+mZrH8HxXUTzGDsxL6I6YMo= +github.com/ssgreg/logftext v1.1.1/go.mod h1:ONi7K7Bilp5Amyhq3sdVQ1lzzp4n4TyyVP4vpx962mA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/throttled/throttled/v2 v2.12.0 h1:IezKE1uHlYC/0Al05oZV6Ar+uN/znw3cy9J8banxhEY= +github.com/throttled/throttled/v2 v2.12.0/go.mod h1:+EAvrG2hZAQTx8oMpBu8fq6Xmm+d1P2luKK7fIY1Esc= +github.com/vasayxtx/go-glob v1.2.0 h1:t+/v9ROAeUVD2OLMcoS7yF6ojqaXSSRInAJ0vWOTU1g= +github.com/vasayxtx/go-glob v1.2.0/go.mod h1:wEj3yNgEm7emEVHCleh9WlNRoW9r3OsajUFgPvSLle0= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opentelemetry.io/otel v0.14.0/go.mod h1:vH5xEuwy7Rts0GNtsCW3HYQoZDY+OmBJ6t1bFGGlxgw= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI= +golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211111213525-f221eed1c01e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= +golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 h1:e7S5W7MGGLaSu8j3YjdezkZ+m1/Nm0uRVRMEMGk26Xs= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= +google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/idptest/doc.go b/idptest/doc.go new file mode 100644 index 0000000..518a8d4 --- /dev/null +++ b/idptest/doc.go @@ -0,0 +1,10 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package idptest provides helper primitives and functions required for +// testing signing and key generation and a simple HTTP server +// with JWKS, issuer and IDP configuration endpoints. +package idptest diff --git a/idptest/grpc_server.go b/idptest/grpc_server.go new file mode 100644 index 0000000..35884af --- /dev/null +++ b/idptest/grpc_server.go @@ -0,0 +1,124 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "context" + "fmt" + "net" + "sync/atomic" + "time" + + "github.com/acronis/go-appkit/testutil" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/reflection" + "google.golang.org/grpc/status" + + "github.com/acronis/go-authkit/idptoken/pb" +) + +type GRPCTokenCreator interface { + CreateToken(ctx context.Context, req *pb.CreateTokenRequest) (*pb.CreateTokenResponse, error) +} + +type GRPCTokenIntrospector interface { + IntrospectToken(ctx context.Context, req *pb.IntrospectTokenRequest) (*pb.IntrospectTokenResponse, error) +} + +type GRPCServer struct { + pb.UnimplementedIDPTokenServiceServer + *grpc.Server + addr atomic.Value + serverOpts []grpc.ServerOption + tokenIntrospector GRPCTokenIntrospector + tokenCreator GRPCTokenCreator +} + +type GRPCServerOption func(*GRPCServer) + +func WithGRPCAddr(addr string) GRPCServerOption { + return func(server *GRPCServer) { + server.addr.Store(addr) + } +} + +func WithGRPCServerOptions(opts ...grpc.ServerOption) GRPCServerOption { + return func(s *GRPCServer) { + s.serverOpts = opts + } +} + +func WithGRPCTokenIntrospector(tokenIntrospector GRPCTokenIntrospector) GRPCServerOption { + return func(s *GRPCServer) { + s.tokenIntrospector = tokenIntrospector + } +} + +func WithGRPCTokenCreator(tokenCreator GRPCTokenCreator) GRPCServerOption { + return func(s *GRPCServer) { + s.tokenCreator = tokenCreator + } +} + +// NewGRPCServer creates a new instance of GRPCServer. +func NewGRPCServer( + opts ...GRPCServerOption, +) *GRPCServer { + srv := &GRPCServer{} + for _, opt := range opts { + opt(srv) + } + srv.Server = grpc.NewServer(srv.serverOpts...) + pb.RegisterIDPTokenServiceServer(srv.Server, srv) + reflection.Register(srv.Server) + return srv +} + +// Addr returns the server address. +func (s *GRPCServer) Addr() string { + return s.addr.Load().(string) +} + +// Start starts the GRPC server +func (s *GRPCServer) Start() error { + addr, ok := s.addr.Load().(string) + if !ok { + addr = localhostWithDynamicPortAddr + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("listen tcp: %w", err) + } + s.addr.Store(ln.Addr().String()) + + go func() { _ = s.Serve(ln) }() + + return nil +} + +// StartAndWaitForReady starts the server waits for the server to start listening. +func (s *GRPCServer) StartAndWaitForReady(timeout time.Duration) error { + if err := s.Start(); err != nil { + return fmt.Errorf("start server: %w", err) + } + return testutil.WaitListeningServer(s.Addr(), timeout) +} + +func (s *GRPCServer) CreateToken(ctx context.Context, req *pb.CreateTokenRequest) (*pb.CreateTokenResponse, error) { + if s.tokenCreator != nil { + return s.tokenCreator.CreateToken(ctx, req) + } + return nil, status.Errorf(codes.Unimplemented, "method CreateToken not implemented") +} + +func (s *GRPCServer) IntrospectToken(ctx context.Context, req *pb.IntrospectTokenRequest) (*pb.IntrospectTokenResponse, error) { + if s.tokenIntrospector != nil { + return s.tokenIntrospector.IntrospectToken(ctx, req) + } + return nil, status.Errorf(codes.Unimplemented, "method IntrospectToken not implemented") +} diff --git a/idptest/http_server.go b/idptest/http_server.go new file mode 100644 index 0000000..6f71464 --- /dev/null +++ b/idptest/http_server.go @@ -0,0 +1,175 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "fmt" + "net" + "net/http" + "sync/atomic" + "time" + + "github.com/acronis/go-appkit/testutil" + + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/jwt" +) + +const ( + OpenIDConfigurationPath = "/.well-known/openid-configuration" + JWKSEndpointPath = "/idp/keys" + TokenEndpointPath = "/idp/token" + TokenIntrospectionEndpointPath = "/idp/introspect_token" // nolint:gosec // This server is used for testing purposes only. +) + +const localhostWithDynamicPortAddr = "127.0.0.1:0" + +// HTTPClaimsProvider is an interface for providing JWT claims in HTTP handlers. +type HTTPClaimsProvider interface { + Provide(r *http.Request) jwt.Claims +} + +// HTTPTokenIntrospector is an interface for introspecting tokens. +type HTTPTokenIntrospector interface { + IntrospectToken(r *http.Request, token string) idptoken.IntrospectionResult +} + +type HTTPServerOption func(s *HTTPServer) + +// WithHTTPAddress is an option to set HTTP server address. +func WithHTTPAddress(addr string) HTTPServerOption { + return func(s *HTTPServer) { + s.addr.Store(addr) + } +} + +// WithHTTPOpenIDConfigurationHandler is an option to set custom handler for GET /.well-known/openid-configuration. +// Otherwise, OpenIDConfigurationHandler will be used. +func WithHTTPOpenIDConfigurationHandler(handler http.HandlerFunc) HTTPServerOption { + return func(s *HTTPServer) { + s.OpenIDConfigurationHandler = handler + } +} + +// WithHTTPKeysHandler is an option to set custom handler for GET /idp/keys. +// Otherwise, JWKSHandler will be used. +func WithHTTPKeysHandler(handler http.Handler) HTTPServerOption { + return func(s *HTTPServer) { + s.KeysHandler = handler + } +} + +// WithHTTPPublicJWKS is an option to set public JWKS for JWKSHandler which will be used for GET /idp/keys. +func WithHTTPPublicJWKS(keys []PublicJWK) HTTPServerOption { + return func(s *HTTPServer) { + s.KeysHandler = &JWKSHandler{PublicJWKS: keys} + } +} + +// WithHTTPTokenHandler is an option to set custom handler for POST /idp/token. +func WithHTTPTokenHandler(handler http.Handler) HTTPServerOption { + return func(s *HTTPServer) { + s.TokenHandler = handler + } +} + +// WithHTTPClaimsProvider is an option to set ClaimsProvider for TokenHandler +// which will be used for POST /idp/token. +func WithHTTPClaimsProvider(claimsProvider HTTPClaimsProvider) HTTPServerOption { + return func(s *HTTPServer) { + s.TokenHandler = &TokenHandler{ClaimsProvider: claimsProvider} + } +} + +// WithHTTPIntrospectTokenHandler is an option to set custom handler for POST /idp/introspect_token. +func WithHTTPIntrospectTokenHandler(handler http.Handler) HTTPServerOption { + return func(s *HTTPServer) { + s.TokenIntrospectionHandler = handler + } +} + +// WithHTTPTokenIntrospector is an option to set TokenIntrospector for TokenIntrospectionHandler +// which will be used for POST /idp/introspect_token. +func WithHTTPTokenIntrospector(introspector HTTPTokenIntrospector) HTTPServerOption { + return func(s *HTTPServer) { + s.TokenIntrospectionHandler = &TokenIntrospectionHandler{TokenIntrospector: introspector} + } +} + +// HTTPServer is a mock IDP server for testing purposes. +type HTTPServer struct { + *http.Server + addr atomic.Value + KeysHandler http.Handler + TokenHandler http.Handler + TokenIntrospectionHandler http.Handler + OpenIDConfigurationHandler http.Handler + Router *http.ServeMux +} + +// NewHTTPServer creates a new IDPMockServer with provided options. +func NewHTTPServer(options ...HTTPServerOption) *HTTPServer { + s := &HTTPServer{ + Router: http.NewServeMux(), + TokenHandler: &TokenHandler{}, + KeysHandler: &JWKSHandler{}, + TokenIntrospectionHandler: &TokenIntrospectionHandler{}, + } + s.OpenIDConfigurationHandler = &OpenIDConfigurationHandler{ + BaseURLFunc: s.URL, + JWKSURL: JWKSEndpointPath, + TokenEndpointURL: TokenEndpointPath, + IntrospectionEndpointURL: TokenIntrospectionEndpointPath, + } + + for _, opt := range options { + opt(s) + } + + s.Router.Handle(OpenIDConfigurationPath, s.OpenIDConfigurationHandler) + s.Router.Handle(JWKSEndpointPath, s.KeysHandler) + s.Router.Handle(TokenEndpointPath, s.TokenHandler) + s.Router.Handle(TokenIntrospectionEndpointPath, s.TokenIntrospectionHandler) + + // nolint:gosec // This server is used for testing purposes only. + s.Server = &http.Server{Handler: s.Router} + + return s +} + +// URL method returns the URL of the server. +func (s *HTTPServer) URL() string { + if srvURL := s.addr.Load(); srvURL != nil { + return "http://" + srvURL.(string) + } + return "" +} + +// Start starts the HTTPServer. +func (s *HTTPServer) Start() error { + addr, ok := s.addr.Load().(string) + if !ok { + addr = localhostWithDynamicPortAddr + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("listen tcp: %w", err) + } + s.addr.Store(ln.Addr().String()) + + go func() { _ = s.Server.Serve(ln) }() + + return nil +} + +// StartAndWaitForReady starts the server waits for the server to start listening. +func (s *HTTPServer) StartAndWaitForReady(timeout time.Duration) error { + if err := s.Start(); err != nil { + return fmt.Errorf("start server: %w", err) + } + return testutil.WaitListeningServer(s.addr.Load().(string), timeout) +} diff --git a/idptest/jwks_handler.go b/idptest/jwks_handler.go new file mode 100644 index 0000000..32256b0 --- /dev/null +++ b/idptest/jwks_handler.go @@ -0,0 +1,100 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "encoding/json" + "fmt" + "net/http" + "sync/atomic" +) + +// TestKeyID is a key ID of the pre-defined key for testing. +const TestKeyID = "fac01c070cd08ba08809762da6e4f74af14e4790" + +// TestPlainPrivateJWK is a plaintext representation of the pre-defined private key for testing. +// nolint: lll +const TestPlainPrivateJWK = ` +{ + "alg": "RS256", + "d": "U4iZRcf35HT68wCF4cXPvOCU-aHbHoxkb99okPIf_pyexcznb3AjJCS9PRW3MnR1UkcDJ509Cwq7HTZ6dSPO8bMagGlFR4PttNgYzRg793r7ZamhzjPy_Udr35a79z6Q3rLBzeyFljhZ708cgU-tYw_7KpPytPN9cvr9MuvHtRWZlYuFRIql0PiOkq9hLMz_rLGb2rEmlQ9Bxk0tAOct9et-k8qgqwUjd0APgyHRxBU3gKUcbwgIb4KjYzypVjAW8Y3eIN8DwJg8P9AsMWdyKLW36exN_jZXGq6HrdecqV5hOGRELQ3Ok4sn-XMEyYVu9urQkZIRHGsbvUdwXskYKQ", + "dp": "jzn3HUi8J7QJF0JIT1VRbX8ngf4c7EDpV76rjTjdDNgGQJF6RZ34DfZoSWJlnoS2aJl2LW_Z2dXUnSi2JzVK_joEmFtMCe6vEiYKl3_2Avw43wfSJ1Kj7CTAOlRiifzdP9RosoYgznLjKq9_WyKlFVy7jXN11f-SBjiwYGyn_FM", + "dq": "pV6go3sp23q7jJP0DMrY0fmrC3AhWTGB3hb9w_KKVDV_J1ljSCpkSB55FfHkCsiTFfBHieCThNl7iWgl0eZJ7TtsdVUSj_dZfAMfi49nj6GBa5mjEUMSGgtUrqWNWf31rKKXz4Y4o0A6U8N8FuX34ANCj9hEBE0UIzdV_e7L1AE", + "e": "AQAB", + "kid": "fac01c070cd08ba08809762da6e4f74af14e4790", + "kty": "RSA", + "n": "mWeDDhcnVdKWbYGubOB7v1rZ395noYk-MFV0Ik78nLsJc1Ni3-GaWpJOTfCFivDP6DcS68Q04olx6_CleaDWU2KHeZE9PuJcW1_Xot3w1U2WZYpzl5_E5jqHjq1-nnOfe5Mq5SbpoZi3o3-QjktiSgaZ6w-575anM-6VhfxyS0s_DKGJHzyka1hJIoGb8vBstKS6oVLcgjQO3JR_Uy4XMdO9s3z-t3_4sO7qtHuEmqFUnaUx5MuLmZnV0hWyLHoNtEQZrf6X5lcnSj-6QerRihJdQeFDm494D96UwjKt70xgbAMvY-H2RcCJ5IqB2jvumqACt70twX7VCeS8FDMP_w", + "p": "w7rqemF-CmOU2X4p_4yzZaVq5CYmq9f-d1QLfK9AdMhIAPAlGxIkevXq6dAnjLWLJ9ksuOFkjWpoNI40JyhPJqif8U8WFyDqMsAEFif4HYVh-iR3NMr489lExBqx-YmmYHJ-pXxpcQhwAIbUkS-iF4eIx44JwVPNniU97Djy_ws", + "q": "yKQfjhWZSFzsn1CveQS6X6H1GtIbpWW9WBR0TFyUWrDtBxe1ivv21ie9hMDhpwk9t9ONUXqt-nNDMtK558q_fGKzMDwYIztX5vXRW9MMR6A7gylSGVspsUbk-egE2dXpwaGwdwr1RvFHEjBNeJQWxvQH-g-QNhJQm6gBdzn6210", + "qi": "Yt33e8KxCstCfgD4MvPg-uTVj6o2f893zbast8b_yunEBZK-c4WnJ73Taj7lOB2iME97XrBsx3f-jdslt6xHd9h0mam_Fi53JxQDoiyPcLWfcgcMY2w4jjoY_-Iqtnnisf7tHGgrba9eyNHRl91oXFgoaduNmeUs1z_yF_GARJo", + "use": "sig" +} +` + +type PublicJWK struct { + Alg string `json:"alg"` + E string `json:"e"` + Kid string `json:"kid"` + Kty string `json:"kty"` + N string `json:"n"` + Use string `json:"use"` +} + +func GetTestPublicJWKS() []PublicJWK { + return []PublicJWK{ + { + Alg: "RS256", + E: "AQAB", + Kid: TestKeyID, + Kty: "RSA", + N: "mWeDDhcnVdKWbYGubOB7v1rZ395noYk-MFV0Ik78nLsJc1Ni3-GaWpJOTfCFivDP6DcS68Q04olx6_CleaDWU2KHeZE9PuJcW1_Xot3w1U2WZYpzl5_E5jqHjq1-nnOfe5Mq5SbpoZi3o3-QjktiSgaZ6w-575anM-6VhfxyS0s_DKGJHzyka1hJIoGb8vBstKS6oVLcgjQO3JR_Uy4XMdO9s3z-t3_4sO7qtHuEmqFUnaUx5MuLmZnV0hWyLHoNtEQZrf6X5lcnSj-6QerRihJdQeFDm494D96UwjKt70xgbAMvY-H2RcCJ5IqB2jvumqACt70twX7VCeS8FDMP_w", // nolint:lll + Use: "sig", + }, + { + Alg: "RS256", + E: "AQAB", + Kid: "737c5114f09b5ed05276bd4b520245982f7fb29f", + Kty: "RSA", + N: "51gGypRFvhTziiCLW3emsFx80G3ljpoYdDdieYM-yfvv6cfpkiEnxRRig5JdJ62vrENgbZi1GZpvTs3B7ly7Z4FI6EM-5e8vIkQSYuE3sXU7QsxEFjtMUm31kao4179gmIIrycHl5M1HE2FU2Ssgf7VuKIVmLvDypNHgBb8cV2XKu_PiGHk2turbKZXxegJTiMBYrgKSaEuBUi3WC3j-onHmQriThchQujmXVMFQ-5syNkUX7hM8PKKONkFUhKANnh0Om8_Sc3bcYZAIoFA2cD-PXopJUQa8GLRfWLExVHRvp-4_vtDYbEAeipPYz2cRmEoMKiLRk8ZpLI6M71ugLQ", // nolint:lll + Use: "sig", + }, + } +} + +// JWKSHandler is an HTTP handler that responds JWKS. +type JWKSHandler struct { + servedCount atomic.Uint64 + PublicJWKS []PublicJWK +} + +func (h *JWKSHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(rw, "Only GET method is allowed", http.StatusMethodNotAllowed) + return + } + + h.servedCount.Add(1) + + rw.Header().Set("Content-Type", "application/json") + publicJWKS := h.PublicJWKS + if len(publicJWKS) == 0 { + publicJWKS = GetTestPublicJWKS() + } + if err := json.NewEncoder(rw).Encode(PublicJWKSResponse{Keys: publicJWKS}); err != nil { + http.Error(rw, fmt.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } +} + +// ServedCount returns the number of times JWKS handler has been served. +func (h *JWKSHandler) ServedCount() uint64 { + return h.servedCount.Load() +} + +type PublicJWKSResponse struct { + Keys []PublicJWK `json:"keys"` +} diff --git a/idptest/jwt.go b/idptest/jwt.go new file mode 100644 index 0000000..172dee5 --- /dev/null +++ b/idptest/jwt.go @@ -0,0 +1,86 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "crypto" + "encoding/json" + + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/mendsley/gojwk" + + "github.com/acronis/go-authkit/idptoken" +) + +// SignToken signs token with key. +func SignToken(token *jwtgo.Token, rsaPrivateKey interface{}) (string, error) { + return token.SignedString(rsaPrivateKey) +} + +// MakeTokenStringWithHeader create test signed token with claims and headers. +func MakeTokenStringWithHeader( + claims jwtgo.Claims, kid string, rsaPrivateKey interface{}, header map[string]interface{}, +) (string, error) { + token := jwtgo.NewWithClaims(jwtgo.SigningMethodRS256, claims) + token.Header["typ"] = idptoken.JWTTypeAccessToken + token.Header["kid"] = kid + for k, v := range header { + token.Header[k] = v + } + return SignToken(token, rsaPrivateKey) +} + +// MustMakeTokenStringWithHeader create test signed token with claims and headers. +// It panics if error occurs. +func MustMakeTokenStringWithHeader( + claims jwtgo.Claims, kid string, rsaPrivateKey interface{}, header map[string]interface{}, +) string { + token, err := MakeTokenStringWithHeader(claims, kid, rsaPrivateKey, header) + if err != nil { + panic(err) + } + return token +} + +// MakeTokenString create signed token with claims. +func MakeTokenString(claims jwtgo.Claims, kid string, rsaPrivateKey interface{}) (string, error) { + return MakeTokenStringWithHeader(claims, kid, rsaPrivateKey, nil) +} + +// MustMakeTokenString create signed token with claims. +// It panics if error occurs. +func MustMakeTokenString(claims jwtgo.Claims, kid string, rsaPrivateKey interface{}) string { + token, err := MakeTokenStringWithHeader(claims, kid, rsaPrivateKey, nil) + if err != nil { + panic(err) + } + return token +} + +// GetTestRSAPrivateKey returns pre-defined RSA private key for testing. +func GetTestRSAPrivateKey() crypto.PrivateKey { + var privKey gojwk.Key + _ = json.Unmarshal([]byte(TestPlainPrivateJWK), &privKey) + rsaPrivKey, _ := privKey.DecodePrivateKey() + return rsaPrivKey +} + +// MakeTokenStringSignedWithTestKey create test token signed with the pre-defined private key (TestKeyID) for testing. +func MakeTokenStringSignedWithTestKey(claims jwtgo.Claims) (string, error) { + return MakeTokenStringWithHeader(claims, TestKeyID, GetTestRSAPrivateKey(), nil) +} + +// MustMakeTokenStringSignedWithTestKey create test token signed +// with the pre-defined private key (TestKeyID) for testing. +// It panics if error occurs. +func MustMakeTokenStringSignedWithTestKey(claims jwtgo.Claims) string { + token, err := MakeTokenStringSignedWithTestKey(claims) + if err != nil { + panic(err) + } + return token +} diff --git a/idptest/jwt_test.go b/idptest/jwt_test.go new file mode 100644 index 0000000..a8056e9 --- /dev/null +++ b/idptest/jwt_test.go @@ -0,0 +1,54 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +const testIss = "test-issuer" + +func TestMakeTokenStringWithHeader(t *testing.T) { + jwksServer := httptest.NewServer(&JWKSHandler{}) + defer jwksServer.Close() + + issuerConfigServer := httptest.NewServer(&OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + logger := log.NewDisabledLogger() + + jwtClaims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{ + {ResourceNamespace: "policy_manager", Role: "admin"}, + }, + } + + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + parsedClaims, err := parser.Parse(context.Background(), MustMakeTokenStringSignedWithTestKey(jwtClaims)) + require.NoError(t, err) + require.Equal( + t, + []jwt.AccessPolicy{{ResourceNamespace: "policy_manager", Role: "admin"}}, + parsedClaims.Scope, + ) +} diff --git a/idptest/openid_configuration_handler.go b/idptest/openid_configuration_handler.go new file mode 100644 index 0000000..a790bc2 --- /dev/null +++ b/idptest/openid_configuration_handler.go @@ -0,0 +1,60 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "encoding/json" + "fmt" + "net/http" + "sync/atomic" + + "github.com/acronis/go-authkit/internal/idputil" +) + +// OpenIDConfigurationHandler is an HTTP handler that responds token's issuer OpenID configuration. +type OpenIDConfigurationHandler struct { + servedCount atomic.Uint64 + BaseURLFunc func() string // for cases when 'host:port' of providers' addresses to be determined during runtime + JWKSURL string + TokenEndpointURL string + IntrospectionEndpointURL string +} + +func (h *OpenIDConfigurationHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(rw, "Only GET method is allowed", http.StatusMethodNotAllowed) + return + } + + h.servedCount.Add(1) + + openIDCfg := idputil.OpenIDConfiguration{ + TokenURL: h.makeEndpointURL(h.TokenEndpointURL, TokenIntrospectionEndpointPath), + IntrospectionEndpoint: h.makeEndpointURL(h.IntrospectionEndpointURL, TokenIntrospectionEndpointPath), + JWKSURI: h.makeEndpointURL(h.JWKSURL, JWKSEndpointPath), + } + rw.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(rw).Encode(openIDCfg); err != nil { + http.Error(rw, fmt.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } +} + +func (h *OpenIDConfigurationHandler) makeEndpointURL(endpointURL string, defaultPath string) string { + if endpointURL == "" { + endpointURL = defaultPath + } + if h.BaseURLFunc != nil { + endpointURL = h.BaseURLFunc() + endpointURL + } + return endpointURL +} + +// ServedCount returns the number of times the handler has been served. +func (h *OpenIDConfigurationHandler) ServedCount() uint64 { + return h.servedCount.Load() +} diff --git a/idptest/token_handlers.go b/idptest/token_handlers.go new file mode 100644 index 0000000..a260aa6 --- /dev/null +++ b/idptest/token_handlers.go @@ -0,0 +1,103 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptest + +import ( + "encoding/json" + "fmt" + "net/http" + "sync/atomic" + "time" +) + +// TokenHandler is an implementation of a handler responding with IDP token. +type TokenHandler struct { + servedCount atomic.Uint64 + ClaimsProvider HTTPClaimsProvider +} + +func (h *TokenHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(rw, "Only POST method is allowed", http.StatusMethodNotAllowed) + return + } + + h.servedCount.Add(1) + + if h.ClaimsProvider == nil { + http.Error(rw, "ClaimsProvider for TokenHandler is not configured", http.StatusInternalServerError) + return + } + + claims := h.ClaimsProvider.Provide(r) + + token, err := MakeTokenStringWithHeader(claims, TestKeyID, GetTestRSAPrivateKey(), nil) + if err != nil { + http.Error(rw, fmt.Sprintf("Claims provider failed generate token: %v", err), http.StatusInternalServerError) + return + } + + expiresIn := claims.ExpiresAt.Unix() - time.Now().UTC().Unix() + if expiresIn < 0 { + expiresIn = 0 + } + + response := struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + }{ + AccessToken: token, + ExpiresIn: expiresIn, + } + rw.Header().Set("Content-Type", "application/json") + if err = json.NewEncoder(rw).Encode(response); err != nil { + http.Error(rw, fmt.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } +} + +// ServedCount returns the number of times the handler has been served. +func (h *TokenHandler) ServedCount() uint64 { + return h.servedCount.Load() +} + +type TokenIntrospectionHandler struct { + servedCount atomic.Uint64 + TokenIntrospector HTTPTokenIntrospector +} + +func (h *TokenIntrospectionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(rw, "Only POST method is allowed", http.StatusMethodNotAllowed) + return + } + + h.servedCount.Add(1) + + if h.TokenIntrospector == nil { + http.Error(rw, "HTTPTokenIntrospector for TokenIntrospectionHandler is not configured", http.StatusInternalServerError) + return + } + + token := r.FormValue("token") + if token == "" { + http.Error(rw, "Token is required", http.StatusBadRequest) + return + } + introspectResult := h.TokenIntrospector.IntrospectToken(r, token) + + rw.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(rw).Encode(introspectResult); err != nil { + http.Error(rw, fmt.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } +} + +// ServedCount returns the number of times the handler has been served. +func (h *TokenIntrospectionHandler) ServedCount() uint64 { + return h.servedCount.Load() +} diff --git a/idptest/token_provider.go b/idptest/token_provider.go new file mode 100644 index 0000000..1a3a386 --- /dev/null +++ b/idptest/token_provider.go @@ -0,0 +1,26 @@ +package idptest + +import ( + "context" + "sync/atomic" +) + +type SimpleTokenProvider struct { + token atomic.Value +} + +func NewSimpleTokenProvider(token string) *SimpleTokenProvider { + tp := &SimpleTokenProvider{} + tp.SetToken(token) + return tp +} + +func (m *SimpleTokenProvider) GetToken(ctx context.Context, scope ...string) (string, error) { + return m.token.Load().(string), nil +} + +func (m *SimpleTokenProvider) Invalidate() {} + +func (m *SimpleTokenProvider) SetToken(token string) { + m.token.Store(token) +} diff --git a/idptoken/caching_introspector.go b/idptoken/caching_introspector.go new file mode 100644 index 0000000..e00a807 --- /dev/null +++ b/idptoken/caching_introspector.go @@ -0,0 +1,194 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken + +import ( + "context" + "crypto/sha256" + "fmt" + "time" + "unsafe" + + "github.com/acronis/go-appkit/lrucache" + + "github.com/acronis/go-authkit/jwt" +) + +const ( + DefaultIntrospectionClaimsCacheMaxEntries = 1000 + DefaultIntrospectionClaimsCacheTTL = 1 * time.Minute + DefaultIntrospectionNegativeCacheMaxEntries = 1000 + DefaultIntrospectionNegativeCacheTTL = 10 * time.Minute +) + +type IntrospectionClaimsCacheItem struct { + Claims *jwt.Claims + TokenType string + CreatedAt time.Time +} + +type IntrospectionClaimsCache interface { + Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionClaimsCacheItem, bool) + Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionClaimsCacheItem) + Purge(ctx context.Context) + Len(ctx context.Context) int +} + +type IntrospectionNegativeCacheItem struct { + CreatedAt time.Time +} + +type IntrospectionNegativeCache interface { + Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionNegativeCacheItem, bool) + Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionNegativeCacheItem) + Purge(ctx context.Context) + Len(ctx context.Context) int +} + +type CachingIntrospectorOpts struct { + IntrospectorOpts + ClaimsCache CachingIntrospectorCacheOpts + NegativeCache CachingIntrospectorCacheOpts +} + +type CachingIntrospectorCacheOpts struct { + Enabled bool + MaxEntries int + TTL time.Duration +} + +type CachingIntrospector struct { + *Introspector + ClaimsCache IntrospectionClaimsCache + NegativeCache IntrospectionNegativeCache + claimsCacheTTL time.Duration + negativeCacheTTL time.Duration +} + +func NewCachingIntrospector(tokenProvider IntrospectionTokenProvider) (*CachingIntrospector, error) { + return NewCachingIntrospectorWithOpts(tokenProvider, CachingIntrospectorOpts{}) +} + +func NewCachingIntrospectorWithOpts( + tokenProvider IntrospectionTokenProvider, opts CachingIntrospectorOpts, +) (*CachingIntrospector, error) { + if !opts.ClaimsCache.Enabled && !opts.NegativeCache.Enabled { + return nil, fmt.Errorf("at least one of claims or negative cache must be enabled") + } + + introspector := NewIntrospectorWithOpts(tokenProvider, opts.IntrospectorOpts) + + // Building claims cache if needed. + var claimsCache IntrospectionClaimsCache = &disabledIntrospectionClaimsCache{} + if opts.ClaimsCache.Enabled { + if opts.ClaimsCache.TTL == 0 { + opts.ClaimsCache.TTL = DefaultIntrospectionClaimsCacheTTL + } + if opts.ClaimsCache.MaxEntries == 0 { + opts.ClaimsCache.MaxEntries = DefaultIntrospectionClaimsCacheMaxEntries + } + cache, err := lrucache.New[[sha256.Size]byte, IntrospectionClaimsCacheItem]( + opts.ClaimsCache.MaxEntries, introspector.promMetrics.TokenClaimsCache) + if err != nil { + return nil, err + } + claimsCache = &introspectionCacheLRUAdapter[[sha256.Size]byte, IntrospectionClaimsCacheItem]{cache} + } + + // Building negative cache if needed. + var negativeCache IntrospectionNegativeCache = &disabledIntrospectionNegativeCache{} + if opts.NegativeCache.Enabled { + if opts.NegativeCache.TTL == 0 { + opts.NegativeCache.TTL = DefaultIntrospectionNegativeCacheTTL + } + if opts.NegativeCache.MaxEntries == 0 { + opts.NegativeCache.MaxEntries = DefaultIntrospectionNegativeCacheMaxEntries + } + cache, err := lrucache.New[[sha256.Size]byte, IntrospectionNegativeCacheItem]( + opts.NegativeCache.MaxEntries, introspector.promMetrics.TokenNegativeCache) + if err != nil { + return nil, err + } + negativeCache = &introspectionCacheLRUAdapter[[sha256.Size]byte, IntrospectionNegativeCacheItem]{cache} + } + + return &CachingIntrospector{ + Introspector: introspector, + ClaimsCache: claimsCache, + NegativeCache: negativeCache, + claimsCacheTTL: opts.ClaimsCache.TTL, + negativeCacheTTL: opts.NegativeCache.TTL, + }, nil +} + +func (i *CachingIntrospector) IntrospectToken(ctx context.Context, token string) (IntrospectionResult, error) { + cacheKey := sha256.Sum256( + unsafe.Slice(unsafe.StringData(token), len(token))) // nolint:gosec // prevent redundant slice copying + + if c, ok := i.ClaimsCache.Get(ctx, cacheKey); ok && c.CreatedAt.Add(i.claimsCacheTTL).After(time.Now()) { + return IntrospectionResult{Active: true, TokenType: c.TokenType, Claims: *c.Claims}, nil + } + if c, ok := i.NegativeCache.Get(ctx, cacheKey); ok && c.CreatedAt.Add(i.negativeCacheTTL).After(time.Now()) { + return IntrospectionResult{Active: false}, nil + } + + introspectionResult, err := i.Introspector.IntrospectToken(ctx, token) + if err != nil { + return IntrospectionResult{}, err + } + if introspectionResult.Active { + i.ClaimsCache.Add(ctx, cacheKey, IntrospectionClaimsCacheItem{ + Claims: &introspectionResult.Claims, + TokenType: introspectionResult.TokenType, + CreatedAt: time.Now(), + }) + } else { + i.NegativeCache.Add(ctx, cacheKey, IntrospectionNegativeCacheItem{CreatedAt: time.Now()}) + } + + return introspectionResult, nil +} + +type introspectionCacheLRUAdapter[K comparable, V any] struct { + cache *lrucache.LRUCache[K, V] +} + +func (a *introspectionCacheLRUAdapter[K, V]) Get(_ context.Context, key K) (V, bool) { + return a.cache.Get(key) +} + +func (a *introspectionCacheLRUAdapter[K, V]) Add(_ context.Context, key K, val V) { + a.cache.Add(key, val) +} + +func (a *introspectionCacheLRUAdapter[K, V]) Purge(ctx context.Context) { + a.cache.Purge() +} + +func (a *introspectionCacheLRUAdapter[K, V]) Len(ctx context.Context) int { + return a.cache.Len() +} + +type disabledIntrospectionClaimsCache struct{} + +func (c *disabledIntrospectionClaimsCache) Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionClaimsCacheItem, bool) { + return IntrospectionClaimsCacheItem{}, false +} +func (c *disabledIntrospectionClaimsCache) Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionClaimsCacheItem) { +} +func (c *disabledIntrospectionClaimsCache) Purge(ctx context.Context) {} +func (c *disabledIntrospectionClaimsCache) Len(ctx context.Context) int { return 0 } + +type disabledIntrospectionNegativeCache struct{} + +func (c *disabledIntrospectionNegativeCache) Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionNegativeCacheItem, bool) { + return IntrospectionNegativeCacheItem{}, false +} +func (c *disabledIntrospectionNegativeCache) Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionNegativeCacheItem) { +} +func (c *disabledIntrospectionNegativeCache) Purge(ctx context.Context) {} +func (c *disabledIntrospectionNegativeCache) Len(ctx context.Context) int { return 0 } diff --git a/idptoken/caching_introspector_test.go b/idptoken/caching_introspector_test.go new file mode 100644 index 0000000..b6575f7 --- /dev/null +++ b/idptoken/caching_introspector_test.go @@ -0,0 +1,255 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken_test + +import ( + "context" + "net/http" + "net/url" + gotesting "testing" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/internal/testing" + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +func TestCachingIntrospector_IntrospectToken(t *gotesting.T) { + serverIntrospector := testing.NewHTTPServerTokenIntrospectorMock() + + idpSrv := idptest.NewHTTPServer(idptest.WithHTTPTokenIntrospector(serverIntrospector)) + require.NoError(t, idpSrv.StartAndWaitForReady(time.Second)) + defer func() { _ = idpSrv.Shutdown(context.Background()) }() + + const accessToken = "access-token-with-introspection-permission" + tokenProvider := idptest.NewSimpleTokenProvider(accessToken) + + logger := log.NewDisabledLogger() + jwtParser := jwt.NewParser(jwks.NewClient(http.DefaultClient, logger), logger) + require.NoError(t, jwtParser.AddTrustedIssuerURL(idpSrv.URL())) + serverIntrospector.JWTParser = jwtParser + + jwtExpiresAtInFuture := jwtgo.NewNumericDate(time.Now().Add(time.Hour)) + jwtIssuer := idpSrv.URL() + jwtSubject := uuid.NewString() + jwtID := uuid.NewString() + jwtScope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "account_viewer", + ResourcePath: "resource-" + uuid.NewString(), + }} + + expiredJWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: idpSrv.URL(), + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(-time.Hour)), + }, + }) + activeJWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: jwtIssuer, + Subject: jwtSubject, + ID: jwtID, + ExpiresAt: jwtExpiresAtInFuture, + }, + }) + + opaqueToken1 := "opaque-token-" + uuid.NewString() + opaqueToken2 := "opaque-token-" + uuid.NewString() + opaqueToken3 := "opaque-token-" + uuid.NewString() + opaqueToken1Scope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "admin", + ResourcePath: "resource-" + uuid.NewString(), + }} + opaqueToken2Scope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "event-manager", + Role: "admin", + ResourcePath: "resource-" + uuid.NewString(), + }} + + serverIntrospector.SetScopeForJWTID(jwtID, jwtScope) + serverIntrospector.SetResultForToken(opaqueToken1, idptoken.IntrospectionResult{ + Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}) + serverIntrospector.SetResultForToken(opaqueToken2, idptoken.IntrospectionResult{ + Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken2Scope}}) + serverIntrospector.SetResultForToken(opaqueToken3, idptoken.IntrospectionResult{Active: false}) + + tests := []struct { + name string + introspectorOpts idptoken.CachingIntrospectorOpts + tokens []string + expectedSrvCalled []bool + expectedResult []idptoken.IntrospectionResult + checkError []func(t *gotesting.T, err error) + checkIntrospector func(t *gotesting.T, introspector *idptoken.CachingIntrospector) + delay time.Duration + }{ + { + name: "error, token is not introspectable", + tokens: []string{"", "opaque-token"}, + expectedSrvCalled: []bool{false, false}, + introspectorOpts: idptoken.CachingIntrospectorOpts{ + ClaimsCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + NegativeCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + }, + checkError: []func(t *gotesting.T, err error){ + func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, "token is missing") + }, + func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, "no JWT header found") + }, + }, + checkIntrospector: func(t *gotesting.T, introspector *idptoken.CachingIntrospector) { + require.Equal(t, 0, introspector.ClaimsCache.Len(context.Background())) + require.Equal(t, 0, introspector.NegativeCache.Len(context.Background())) + }, + }, + { + name: "ok, dynamic introspection endpoint, introspected token is expired JWT", + introspectorOpts: idptoken.CachingIntrospectorOpts{ + ClaimsCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + NegativeCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + }, + tokens: repeat(expiredJWT, 2), + expectedSrvCalled: []bool{true, false}, + expectedResult: []idptoken.IntrospectionResult{{Active: false}, {Active: false}}, + checkIntrospector: func(t *gotesting.T, introspector *idptoken.CachingIntrospector) { + require.Equal(t, 0, introspector.ClaimsCache.Len(context.Background())) + require.Equal(t, 1, introspector.NegativeCache.Len(context.Background())) + }, + }, + { + name: "ok, dynamic introspection endpoint, introspected token is JWT", + introspectorOpts: idptoken.CachingIntrospectorOpts{ + ClaimsCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + NegativeCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + }, + tokens: repeat(activeJWT, 2), + expectedSrvCalled: []bool{true, false}, + expectedResult: repeat(idptoken.IntrospectionResult{ + Active: true, + TokenType: idptoken.TokenTypeBearer, + Claims: jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: jwtIssuer, + Subject: jwtSubject, + ID: jwtID, + ExpiresAt: jwtExpiresAtInFuture, + }, + Scope: jwtScope, + }, + }, 2), + checkIntrospector: func(t *gotesting.T, introspector *idptoken.CachingIntrospector) { + require.Equal(t, 1, introspector.ClaimsCache.Len(context.Background())) + require.Equal(t, 0, introspector.NegativeCache.Len(context.Background())) + }, + }, + { + name: "ok, static introspection endpoint, introspected token is opaque", + introspectorOpts: idptoken.CachingIntrospectorOpts{ + IntrospectorOpts: idptoken.IntrospectorOpts{ + StaticHTTPEndpoint: idpSrv.URL() + idptest.TokenIntrospectionEndpointPath, + }, + ClaimsCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + NegativeCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true}, + }, + tokens: []string{opaqueToken1, opaqueToken1, opaqueToken2, opaqueToken2, opaqueToken3, opaqueToken3}, + expectedSrvCalled: []bool{true, false, true, false, true, false}, + expectedResult: []idptoken.IntrospectionResult{ + {Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, + {Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, + {Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken2Scope}}, + {Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken2Scope}}, + {Active: false}, + {Active: false}, + }, + checkIntrospector: func(t *gotesting.T, introspector *idptoken.CachingIntrospector) { + require.Equal(t, 2, introspector.ClaimsCache.Len(context.Background())) + require.Equal(t, 1, introspector.NegativeCache.Len(context.Background())) + }, + }, + { + name: "ok, cache has ttl", + introspectorOpts: idptoken.CachingIntrospectorOpts{ + IntrospectorOpts: idptoken.IntrospectorOpts{ + StaticHTTPEndpoint: idpSrv.URL() + idptest.TokenIntrospectionEndpointPath, + }, + ClaimsCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true, TTL: 100 * time.Millisecond}, + NegativeCache: idptoken.CachingIntrospectorCacheOpts{Enabled: true, TTL: 100 * time.Millisecond}, + }, + tokens: []string{opaqueToken1, opaqueToken1, opaqueToken3, opaqueToken3}, + expectedSrvCalled: []bool{true, true, true, true}, + expectedResult: []idptoken.IntrospectionResult{ + {Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, + {Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, + {Active: false}, + {Active: false}, + }, + checkIntrospector: func(t *gotesting.T, introspector *idptoken.CachingIntrospector) { + require.Equal(t, 1, introspector.ClaimsCache.Len(context.Background())) + require.Equal(t, 1, introspector.NegativeCache.Len(context.Background())) + }, + delay: 200 * time.Millisecond, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *gotesting.T) { + introspector, err := idptoken.NewCachingIntrospectorWithOpts(tokenProvider, tt.introspectorOpts) + require.NoError(t, err) + require.NoError(t, introspector.AddTrustedIssuerURL(idpSrv.URL())) + + for i, token := range tt.tokens { + serverIntrospector.ResetCallsInfo() + + result, introspectErr := introspector.IntrospectToken(context.Background(), token) + if i < len(tt.checkError) { + tt.checkError[i](t, introspectErr) + } else { + require.NoError(t, introspectErr) + require.Equal(t, tt.expectedResult[i], result) + } + + require.Equal(t, tt.expectedSrvCalled[i], serverIntrospector.Called) + if tt.expectedSrvCalled[i] { + require.Equal(t, token, serverIntrospector.LastIntrospectedToken) + require.Equal(t, "Bearer "+accessToken, serverIntrospector.LastAuthorizationHeader) + require.Equal(t, url.Values{"token": {token}}, serverIntrospector.LastFormValues) + } + + time.Sleep(tt.delay) + } + + if tt.checkIntrospector != nil { + tt.checkIntrospector(t, introspector) + } + }) + } +} + +func repeat[V any](v V, n int) []V { + s := make([]V, n) + for i := range s { + s[i] = v + } + return s +} diff --git a/idptoken/config.go b/idptoken/config.go new file mode 100644 index 0000000..92a3349 --- /dev/null +++ b/idptoken/config.go @@ -0,0 +1,61 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken + +import ( + "fmt" + + "github.com/acronis/go-appkit/config" +) + +const ( + cfgKeyIDPURL = "idp.url" + cfgKeyIDPClientID = "idp.clientId" + cfgKeyIDPClientSecret = "idp.clientSecret" +) + +// Config is a configuration for IDP token source. +type Config struct { + URL string + ClientID string + ClientSecret string +} + +var _ config.Config = (*Config)(nil) + +// NewConfig creates a new configuration for IDP token source. +func NewConfig() *Config { + return &Config{} +} + +// SetProviderDefaults sets the default values for the configuration. +func (c *Config) SetProviderDefaults(_ config.DataProvider) { +} + +// Set sets the configuration from the given data provider. +func (c *Config) Set(dp config.DataProvider) (err error) { + if c.URL, err = dp.GetString(cfgKeyIDPURL); err != nil { + return err + } + if c.URL == "" { + return dp.WrapKeyErr(cfgKeyIDPURL, fmt.Errorf("IDP URL is required")) + } + if c.ClientID, err = dp.GetString(cfgKeyIDPClientID); err != nil { + return err + } + if c.ClientID == "" { + return dp.WrapKeyErr(cfgKeyIDPClientID, fmt.Errorf("IDP client ID is required")) + } + if c.ClientSecret, err = dp.GetString(cfgKeyIDPClientSecret); err != nil { + return err + } + if c.ClientSecret == "" { + return dp.WrapKeyErr(cfgKeyIDPClientSecret, fmt.Errorf("IDP client secret is required")) + } + + return nil +} diff --git a/idptoken/config_test.go b/idptoken/config_test.go new file mode 100644 index 0000000..6736ede --- /dev/null +++ b/idptoken/config_test.go @@ -0,0 +1,105 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken_test + +import ( + "bytes" + "os" + "testing" + + "github.com/acronis/go-appkit/config" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptoken" +) + +func TestConfig(t *testing.T) { + type testCase struct { + name string + cfgData string + expectErr bool + errMsg string + setupEnv map[string]string // Environment variables to set + } + + testCases := []testCase{ + { + name: "valid config", + cfgData: ` +idp: + url: https://idp.example.com + clientId: client-id + clientSecret: client-secret +`, + expectErr: false, + }, + { + name: "missing url", + cfgData: ` +idp: + clientId: client-id + clientSecret: client-secret +`, + expectErr: true, + errMsg: `idp.url: IDP URL is required`, + }, + { + name: "missing client ID", + cfgData: ` +idp: + url: https://idp.example.com + clientSecret: client-secret +`, + expectErr: true, + errMsg: `idp.clientId: IDP client ID is required`, + }, + { + name: "missing client secret", + cfgData: ` +idp: + url: https://idp.example.com + clientId: client-id +`, + expectErr: true, + errMsg: `idp.clientSecret: IDP client secret is required`, + }, + { + name: "valid config from Env", + cfgData: ` +idp: + url: https://idp.example.com +`, + setupEnv: map[string]string{ + "IDP_CLIENTID": "client-id", + "IDP_CLIENTSECRET": "client-secret", + }, + expectErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup environment variables if needed + for k, v := range tc.setupEnv { + err := os.Setenv(k, v) + require.NoError(t, err) + defer func(k string) { + require.NoError(t, os.Unsetenv(k)) + }(k) + } + + cfg := idptoken.NewConfig() + err := config.NewDefaultLoader("").LoadFromReader(bytes.NewBufferString(tc.cfgData), config.DataTypeYAML, cfg) + + if tc.expectErr { + require.EqualError(t, err, tc.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/idptoken/doc.go b/idptoken/doc.go new file mode 100644 index 0000000..11bea92 --- /dev/null +++ b/idptoken/doc.go @@ -0,0 +1,10 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package idptoken provides a robust way to request access tokens from IDP. +// Provider is to be used for a single token source. +// MultiSourceProvider to be used for multiple token sources. +package idptoken diff --git a/idptoken/grpc_client.go b/idptoken/grpc_client.go new file mode 100644 index 0000000..c545cea --- /dev/null +++ b/idptoken/grpc_client.go @@ -0,0 +1,182 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "google.golang.org/grpc" + grpccodes "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" + grpcstatus "google.golang.org/grpc/status" + + "github.com/acronis/go-authkit/idptoken/pb" + "github.com/acronis/go-authkit/internal/metrics" + "github.com/acronis/go-authkit/jwt" +) + +// GRPCClientOpts contains options for the GRPCClient. +type GRPCClientOpts struct { + // Logger is a logger for the client. + Logger log.FieldLogger + + // RequestTimeout is a timeout for the gRPC requests. + RequestTimeout time.Duration + + // PrometheusLibInstanceLabel is a label for Prometheus metrics. + // It allows distinguishing metrics from different instances of the same library. + PrometheusLibInstanceLabel string +} + +// GRPCClient is a client for the IDP token service that uses gRPC. +type GRPCClient struct { + client pb.IDPTokenServiceClient + clientConn *grpc.ClientConn + reqTimeout time.Duration + promMetrics *metrics.PrometheusMetrics +} + +const grpcMetaAuthorization = "authorization" + +// NewGRPCClient creates a new GRPCClient instance that communicates with the IDP token service. +func NewGRPCClient( + target string, transportCreds credentials.TransportCredentials, +) (*GRPCClient, error) { + return NewGRPCClientWithOpts(target, transportCreds, GRPCClientOpts{}) +} + +// NewGRPCClientWithOpts creates a new GRPCClient instance that communicates with the IDP token service +// with the specified options. +func NewGRPCClientWithOpts( + target string, transportCreds credentials.TransportCredentials, opts GRPCClientOpts, +) (*GRPCClient, error) { + if opts.Logger == nil { + opts.Logger = log.NewDisabledLogger() + } + if opts.RequestTimeout == 0 { + opts.RequestTimeout = time.Second * 30 + } + dialCtx := context.Background() // context.Background() is ok since we don't use grpc.WithBlock() + conn, err := grpc.DialContext(dialCtx, target, + grpc.WithTransportCredentials(transportCreds), + grpc.WithStatsHandler(&statsHandler{logger: opts.Logger}), + grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), + ) + if err != nil { + return nil, fmt.Errorf("dial to %q: %w", target, err) + } + return &GRPCClient{ + client: pb.NewIDPTokenServiceClient(conn), + clientConn: conn, + reqTimeout: opts.RequestTimeout, + promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "grpc_client"), + }, nil +} + +// Close closes the client gRPC connection. +func (c *GRPCClient) Close() error { + return c.clientConn.Close() +} + +// IntrospectToken introspects the token using the IDP token service. +func (c *GRPCClient) IntrospectToken( + ctx context.Context, token string, scopeFilter []IntrospectionScopeFilterAccessPolicy, accessToken string, +) (IntrospectionResult, error) { + req := pb.IntrospectTokenRequest{ + Token: token, + ScopeFilter: make([]*pb.IntrospectionScopeFilter, len(scopeFilter)), + } + for i := range scopeFilter { + req.ScopeFilter[i] = &pb.IntrospectionScopeFilter{ResourceNamespace: scopeFilter[i].ResourceNamespace} + } + + ctx = metadata.AppendToOutgoingContext(ctx, grpcMetaAuthorization, makeBearerToken(accessToken)) + + ctx, ctxCancel := context.WithTimeout(ctx, c.reqTimeout) + defer ctxCancel() + + const methodName = "IDPTokenService/IntrospectToken" + startTime := time.Now() + resp, err := c.client.IntrospectToken(ctx, &req) + elapsed := time.Since(startTime) + if err != nil { + var code grpccodes.Code + if st, ok := grpcstatus.FromError(err); ok { + code = st.Code() + } + c.promMetrics.ObserveGRPCClientRequest(methodName, code, elapsed) + if code == grpccodes.Unauthenticated { + return IntrospectionResult{}, ErrTokenIntrospectionUnauthenticated + } + return IntrospectionResult{}, fmt.Errorf("introspect token: %w", err) + } + c.promMetrics.ObserveGRPCClientRequest(methodName, grpccodes.OK, elapsed) + + res := IntrospectionResult{ + Active: resp.GetActive(), + TokenType: resp.GetTokenType(), + Claims: jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: resp.GetIss(), + Subject: resp.GetSub(), + Audience: resp.GetAud(), + ID: resp.GetJti(), + }, + SubType: resp.GetSubType(), + ClientID: resp.GetClientId(), + OwnerTenantUUID: resp.GetOwnerTenantUuid(), + Scope: make([]jwt.AccessPolicy, len(resp.GetScope())), + }, + } + if resp.GetExp() != 0 { + res.Claims.ExpiresAt = jwtgo.NewNumericDate(time.Unix(resp.GetExp(), 0)) + } + for i, s := range resp.GetScope() { + res.Claims.Scope[i] = jwt.AccessPolicy{ + ResourceNamespace: s.GetResourceNamespace(), + Role: s.GetRoleName(), + ResourceServerID: s.GetResourceServer(), + ResourcePath: s.GetResourcePath(), + TenantUUID: s.GetTenantUuid(), + } + if s.GetTenantIntId() != 0 { + res.Claims.Scope[i].TenantID = strconv.FormatInt(s.GetTenantIntId(), 10) + } + } + return res, nil +} + +type statsHandler struct { + logger log.FieldLogger +} + +func (sh *statsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { + return ctx +} + +func (sh *statsHandler) HandleRPC(ctx context.Context, s stats.RPCStats) { +} + +func (sh *statsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context { + return ctx +} + +func (sh *statsHandler) HandleConn(ctx context.Context, s stats.ConnStats) { + switch s.(type) { + case *stats.ConnBegin: + sh.logger.Infof("grpc connection established") + case *stats.ConnEnd: + sh.logger.Infof("grpc connection closed") + } +} diff --git a/idptoken/idp_token.proto b/idptoken/idp_token.proto new file mode 100644 index 0000000..4299e6a --- /dev/null +++ b/idptoken/idp_token.proto @@ -0,0 +1,74 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +syntax = "proto3"; + +package idp_token; + +option go_package = "./pb"; + +service IDPTokenService { + // CreateToken creates a new token based on the provided assertion. + // Currently only "urn:ietf:params:oauth:grant-type:jwt-bearer" grant type is supported. + rpc CreateToken (CreateTokenRequest) returns (CreateTokenResponse); + + // IntrospectToken returns information about the token including its scopes. + // The token is considered active if + // 1) it's not expired; + // 2) it's not revoked; + // 3) it has has the valid signature. + rpc IntrospectToken (IntrospectTokenRequest) returns (IntrospectTokenResponse); +} + +message CreateTokenRequest { + reserved 4 to 50; + string grant_type = 1; // example: urn:ietf:params:oauth:grant-type:jwt-bearer + string assertion = 2; + uint32 token_version = 3; +} + +message CreateTokenResponse { + reserved 4 to 50; + string access_token = 1; + string token_type = 2; + int64 expires_in = 3; +} + +message IntrospectionScopeFilter { + reserved 2 to 50; + string ResourceNamespace = 1; +} + +message IntrospectTokenRequest { + reserved 3 to 50; + string token = 1; + repeated IntrospectionScopeFilter scope_filter = 2; +} + +message AccessTokenScope { + reserved 7 to 50; + string tenant_uuid = 1; + int64 tenant_int_id = 2; + string ResourceServer = 3; + string ResourceNamespace = 4; + string ResourcePath = 5; + string RoleName = 6; +} + +message IntrospectTokenResponse { + reserved 12 to 100; + bool active = 1; + string token_type = 2; + int64 exp = 3; + repeated string aud = 4; + string jti = 5; + string iss = 6; + string sub = 7; + string sub_type = 8; + string client_id = 9; + string owner_tenant_uuid = 10; // API client owner tenant UUID. + repeated AccessTokenScope scope = 11; +} \ No newline at end of file diff --git a/idptoken/introspector.go b/idptoken/introspector.go new file mode 100644 index 0000000..0475e01 --- /dev/null +++ b/idptoken/introspector.go @@ -0,0 +1,402 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + + "github.com/acronis/go-authkit/internal/idputil" + "github.com/acronis/go-authkit/internal/metrics" + "github.com/acronis/go-authkit/jwt" +) + +const DefaultRequestTimeout = 30 * time.Second + +const JWTTypeAccessToken = "at+jwt" + +const TokenTypeBearer = "bearer" + +const MinJWTVersionForIntrospection = 0 + +const minAccessTokenProviderInvalidationInterval = time.Minute + +const tokenIntrospectorPromSource = "token_introspector" + +// ErrTokenNotIntrospectable is returned when token is not introspectable. +var ErrTokenNotIntrospectable = errors.New("token is not introspectable") + +// ErrTokenIntrospectionNotNeeded is returned when token introspection is unnecessary +// (i.e., it already contains all necessary information). +var ErrTokenIntrospectionNotNeeded = errors.New("token introspection is not needed") + +// ErrTokenIntrospectionUnauthenticated is returned when token introspection is unauthenticated. +var ErrTokenIntrospectionUnauthenticated = errors.New("token introspection is unauthenticated") + +// TrustedIssNotFoundFallback is a function called when given issuer is not found in the list of trusted ones. +// For example, it could be analyzed and then added to the list by calling AddTrustedIssuerURL method. +type TrustedIssNotFoundFallback func(ctx context.Context, i *Introspector, iss string) (issURL string, issFound bool) + +// IntrospectionTokenProvider is an interface for getting access token for doing introspection. +// The token should have introspection permission. +type IntrospectionTokenProvider interface { + GetToken(ctx context.Context, scope ...string) (string, error) + Invalidate() +} + +// IntrospectionScopeFilterAccessPolicy is an access policy for filtering scopes. +type IntrospectionScopeFilterAccessPolicy struct { + ResourceNamespace string +} + +// IntrospectorOpts is a set of options for creating Introspector. +type IntrospectorOpts struct { + // GRPCClient is a GRPC client for doing introspection. + // If it is set, then introspection will be done using this client. + // Otherwise, introspection will be done via HTTP. + GRPCClient *GRPCClient + + // StaticHTTPEndpoint is a static URL for introspection. + // If it is set, then introspection will be done using this endpoint. + // Otherwise, introspection will be done using issuer URL (/.well-known/openid-configuration response). + // In this case, issuer URL should be present in JWT header or payload. + StaticHTTPEndpoint string + + // HTTPClient is an HTTP client for doing requests to /.well-known/openid-configuration and introspection endpoints. + HTTPClient *http.Client + + // AccessTokenScope is a scope for getting access token for doing introspection. + // The token should have introspection permission. + AccessTokenScope []string + + // ScopeFilter is a list of access policies for filtering scopes during introspection. + // If it is set, then only scopes that match at least one of the policies will be returned. + ScopeFilter []IntrospectionScopeFilterAccessPolicy + + // Logger is a logger for logging errors and debug information. + Logger log.FieldLogger + + // TrustedIssuerNotFoundFallback is a function called + // when given issuer from JWT is not found in the list of trusted ones. + TrustedIssuerNotFoundFallback TrustedIssNotFoundFallback + + // MinJWTVersion is a minimum JWT version for introspection. + // If JWT version is less than this value, then introspection will not be done + // and ErrTokenIntrospectionNotNeeded will be returned. + // Version is a value of "ver" field in JWT header. + // By default, it is 0. + // NOTE: it's a temporary solution for determining whether introspection is needed or not, + // and it will be removed in the future. + MinJWTVersion int + + // PrometheusLibInstanceLabel is a label for Prometheus metrics. + // It allows distinguishing metrics from different instances of the same library. + PrometheusLibInstanceLabel string +} + +// Introspector is a struct for introspecting tokens. +type Introspector struct { + accessTokenProvider IntrospectionTokenProvider + accessTokenProviderInvalidatedAt atomic.Value + accessTokenScope []string + + minJWTVersion int + jwtParser *jwtgo.Parser + + grpcClient *GRPCClient + staticHTTPURL string + httpClient *http.Client + + scopeFilter []IntrospectionScopeFilterAccessPolicy + scopeFilterFormURLEncoded string + + logger log.FieldLogger + + trustedIssuerStore *idputil.TrustedIssuerStore + trustedIssuerNotFoundFallback TrustedIssNotFoundFallback + + promMetrics *metrics.PrometheusMetrics +} + +// IntrospectionResult is a struct for introspection result. +type IntrospectionResult struct { + Active bool `json:"active"` + TokenType string `json:"token_type,omitempty"` + jwt.Claims +} + +// NewIntrospector creates a new Introspector with the given token provider. +func NewIntrospector(tokenProvider IntrospectionTokenProvider) *Introspector { + return NewIntrospectorWithOpts(tokenProvider, IntrospectorOpts{}) +} + +// NewIntrospectorWithOpts creates a new Introspector with the given token provider and options. +// See IntrospectorOpts for more details. +func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opts IntrospectorOpts) *Introspector { + if opts.HTTPClient == nil { + opts.HTTPClient = &http.Client{Timeout: DefaultRequestTimeout} + } + if opts.Logger == nil { + opts.Logger = log.NewDisabledLogger() + } + + values := url.Values{} + for i, policy := range opts.ScopeFilter { + values.Set("scope_filter["+strconv.Itoa(i)+"].rn", policy.ResourceNamespace) + } + scopeFilterFormURLEncoded := values.Encode() + + promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, tokenIntrospectorPromSource) + + return &Introspector{ + accessTokenProvider: accessTokenProvider, + accessTokenScope: opts.AccessTokenScope, + jwtParser: jwtgo.NewParser(), + logger: opts.Logger, + grpcClient: opts.GRPCClient, + httpClient: opts.HTTPClient, + scopeFilterFormURLEncoded: scopeFilterFormURLEncoded, + scopeFilter: opts.ScopeFilter, + staticHTTPURL: opts.StaticHTTPEndpoint, + minJWTVersion: opts.MinJWTVersion, + trustedIssuerStore: idputil.NewTrustedIssuerStore(), + trustedIssuerNotFoundFallback: opts.TrustedIssuerNotFoundFallback, + promMetrics: promMetrics, + } +} + +// IntrospectToken introspects the given token. +func (i *Introspector) IntrospectToken(ctx context.Context, token string) (IntrospectionResult, error) { + introspectFn, err := i.makeIntrospectFuncForToken(ctx, token) + if err != nil { + return IntrospectionResult{}, err + } + + result, err := introspectFn(ctx, token) + if err == nil { + return result, nil + } + + if !errors.Is(err, ErrTokenIntrospectionUnauthenticated) { + return IntrospectionResult{}, err + } + + // If introspection is unauthorized, then invalidate access token (if it is not invalidated recently) and try again. + t, ok := i.accessTokenProviderInvalidatedAt.Load().(time.Time) + now := time.Now() + if !ok || now.Sub(t) > minAccessTokenProviderInvalidationInterval { + i.accessTokenProvider.Invalidate() + i.accessTokenProviderInvalidatedAt.Store(now) + return introspectFn(ctx, token) + } + return IntrospectionResult{}, err +} + +// AddTrustedIssuer adds trusted issuer with specified name and URL. +func (i *Introspector) AddTrustedIssuer(issName, issURL string) { + i.trustedIssuerStore.AddTrustedIssuer(issName, issURL) +} + +// AddTrustedIssuerURL adds trusted issuer URL. +func (i *Introspector) AddTrustedIssuerURL(issURL string) error { + return i.trustedIssuerStore.AddTrustedIssuerURL(issURL) +} + +type introspectFunc func(ctx context.Context, token string) (IntrospectionResult, error) + +func (i *Introspector) makeIntrospectFuncForToken(ctx context.Context, token string) (introspectFunc, error) { + var err error + + if token == "" { + return i.makeStaticIntrospectFuncOrError(fmt.Errorf("token is missing")) + } + + jwtHeaderEndIdx := strings.IndexByte(token, '.') + if jwtHeaderEndIdx == -1 { + return i.makeStaticIntrospectFuncOrError(fmt.Errorf("no JWT header found")) + } + var jwtHeaderBytes []byte + if jwtHeaderBytes, err = i.jwtParser.DecodeSegment(token[:jwtHeaderEndIdx]); err != nil { + return i.makeStaticIntrospectFuncOrError(fmt.Errorf("decode JWT header: %w", err)) + } + jwtHeader := make(map[string]interface{}) + if err = json.Unmarshal(jwtHeaderBytes, &jwtHeader); err != nil { + return i.makeStaticIntrospectFuncOrError(fmt.Errorf("unmarshal JWT header: %w", err)) + } + if typ, ok := jwtHeader["typ"].(string); !ok || !strings.EqualFold(typ, JWTTypeAccessToken) { + return i.makeStaticIntrospectFuncOrError(fmt.Errorf("token type is not %s", JWTTypeAccessToken)) + } + if ver, ok := jwtHeader["ver"].(float64); ok && int(ver) < i.minJWTVersion { + return nil, ErrTokenIntrospectionNotNeeded + } + + if i.staticHTTPURL != "" { + return i.makeIntrospectFuncHTTP(i.staticHTTPURL), nil + } + if i.grpcClient != nil { + return i.makeIntrospectFuncGRPC(), nil + } + + // Try to get issuer from JWT header first and then from JWT payload. + // Issuer is usually presented in the JWT payload (it's an optional field, according to RFC 7519), + // but it could be in the header as well for optimization purposes. + // It's relevant for JWTs with large payloads. + issuer, ok := jwtHeader["iss"].(string) + if !ok || issuer == "" { + jwtPayloadEndIdx := strings.IndexByte(token[jwtHeaderEndIdx+1:], '.') + if jwtPayloadEndIdx == -1 { + return nil, makeTokenNotIntrospectableError(fmt.Errorf("no JWT payload found")) + } + var jwtPayloadBytes []byte + if jwtPayloadBytes, err = i.jwtParser.DecodeSegment( + token[jwtHeaderEndIdx+1 : jwtHeaderEndIdx+1+jwtPayloadEndIdx], + ); err != nil { + return nil, makeTokenNotIntrospectableError(fmt.Errorf("decode JWT payload: %w", err)) + } + var originalClaims jwt.Claims + if err = json.Unmarshal(jwtPayloadBytes, &originalClaims); err != nil { + return nil, makeTokenNotIntrospectableError(fmt.Errorf("unmarshal JWT payload: %w", err)) + } + if originalClaims.Issuer == "" { + return nil, makeTokenNotIntrospectableError(fmt.Errorf("no issuer found in JWT")) + } + issuer = originalClaims.Issuer + } + + issuerURL, ok := i.getURLForIssuerWithCallback(ctx, issuer) + if !ok { + return nil, makeTokenNotIntrospectableError(fmt.Errorf("issuer %q is not trusted", issuer)) + } + + // Try to get introspection endpoint URL from issuer. + introspectionEndpointURL, err := i.getWellKnownIntrospectionEndpointURL(ctx, issuerURL) + if err != nil { + return nil, fmt.Errorf("get introspection endpoint URL: %w", err) + } + return i.makeIntrospectFuncHTTP(introspectionEndpointURL), nil +} + +func (i *Introspector) makeStaticIntrospectFuncOrError(inner error) (introspectFunc, error) { + if i.grpcClient != nil { + return i.makeIntrospectFuncGRPC(), nil + } + if i.staticHTTPURL != "" { + return i.makeIntrospectFuncHTTP(i.staticHTTPURL), nil + } + return nil, makeTokenNotIntrospectableError(inner) +} + +func (i *Introspector) makeIntrospectFuncHTTP(introspectionEndpointURL string) introspectFunc { + return func(ctx context.Context, token string) (IntrospectionResult, error) { + accessToken, err := i.accessTokenProvider.GetToken(ctx, i.accessTokenScope...) + if err != nil { + return IntrospectionResult{}, fmt.Errorf("get access token for doing introspection: %w", err) + } + formEncoded := url.Values{"token": {token}}.Encode() + if i.scopeFilterFormURLEncoded != "" { + formEncoded += "&" + i.scopeFilterFormURLEncoded + } + req, err := http.NewRequest(http.MethodPost, introspectionEndpointURL, strings.NewReader(formEncoded)) + if err != nil { + return IntrospectionResult{}, fmt.Errorf("new request: %w", err) + } + req.Header.Set("Authorization", makeBearerToken(accessToken)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + startTime := time.Now() + resp, err := i.httpClient.Do(req.WithContext(ctx)) + elapsed := time.Since(startTime) + if err != nil { + i.promMetrics.ObserveHTTPClientRequest(http.MethodPost, introspectionEndpointURL, 0, elapsed, metrics.HTTPRequestErrorDo) + return IntrospectionResult{}, fmt.Errorf("do request: %w", err) + } + defer func() { + if closeBodyErr := resp.Body.Close(); closeBodyErr != nil { + i.logger.Error(fmt.Sprintf("closing response body error for POST %s", introspectionEndpointURL), + log.Error(closeBodyErr)) + } + }() + if resp.StatusCode != http.StatusOK { + i.promMetrics.ObserveHTTPClientRequest( + http.MethodPost, introspectionEndpointURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorUnexpectedStatusCode) + if resp.StatusCode == http.StatusUnauthorized { + return IntrospectionResult{}, ErrTokenIntrospectionUnauthenticated + } + return IntrospectionResult{}, fmt.Errorf("unexpected HTTP code %d for POST %s", resp.StatusCode, introspectionEndpointURL) + } + + var res IntrospectionResult + if err = json.NewDecoder(resp.Body).Decode(&res); err != nil { + i.promMetrics.ObserveHTTPClientRequest( + http.MethodPost, introspectionEndpointURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorDecodeBody) + return IntrospectionResult{}, fmt.Errorf("decode response body json for POST %s: %w", introspectionEndpointURL, err) + } + + i.promMetrics.ObserveHTTPClientRequest(http.MethodPost, introspectionEndpointURL, resp.StatusCode, elapsed, "") + return res, nil + } +} + +func (i *Introspector) makeIntrospectFuncGRPC() introspectFunc { + return func(ctx context.Context, token string) (IntrospectionResult, error) { + accessToken, err := i.accessTokenProvider.GetToken(ctx, i.accessTokenScope...) + if err != nil { + return IntrospectionResult{}, fmt.Errorf("get access token for doing introspection: %w", err) + } + res, err := i.grpcClient.IntrospectToken(ctx, token, i.scopeFilter, accessToken) + if err != nil { + return IntrospectionResult{}, fmt.Errorf("introspect token: %w", err) + } + return res, nil + } +} + +func (i *Introspector) getWellKnownIntrospectionEndpointURL(ctx context.Context, issuerURL string) (string, error) { + openIDCfgURL := strings.TrimSuffix(issuerURL, "/") + wellKnownPath + openIDCfg, err := idputil.GetOpenIDConfiguration( + ctx, i.httpClient, openIDCfgURL, nil, i.logger, i.promMetrics) + if err != nil { + return "", fmt.Errorf("get OpenID configuration: %w", err) + } + if openIDCfg.IntrospectionEndpoint == "" { + return "", fmt.Errorf("no introspection endpoint URL found on %s", openIDCfgURL) + } + return openIDCfg.IntrospectionEndpoint, nil +} + +func (i *Introspector) getURLForIssuerWithCallback(ctx context.Context, issuer string) (string, bool) { + issURL, issFound := i.trustedIssuerStore.GetURLForIssuer(issuer) + if issFound { + return issURL, true + } + if i.trustedIssuerNotFoundFallback == nil { + return "", false + } + return i.trustedIssuerNotFoundFallback(ctx, i, issuer) +} + +func makeTokenNotIntrospectableError(inner error) error { + if inner != nil { + return fmt.Errorf("%w: %w", ErrTokenNotIntrospectable, inner) + } + return ErrTokenNotIntrospectable +} + +func makeBearerToken(token string) string { + return "Bearer " + token +} diff --git a/idptoken/introspector_test.go b/idptoken/introspector_test.go new file mode 100644 index 0000000..5757b0c --- /dev/null +++ b/idptoken/introspector_test.go @@ -0,0 +1,306 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken_test + +import ( + "context" + "net/http" + "net/url" + gotesting "testing" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/credentials/insecure" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/idptoken/pb" + "github.com/acronis/go-authkit/internal/testing" + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +func TestIntrospector_IntrospectToken(t *gotesting.T) { + httpServerIntrospector := testing.NewHTTPServerTokenIntrospectorMock() + grpcServerIntrospector := testing.NewGRPCServerTokenIntrospectorMock() + + httpIDPSrv := idptest.NewHTTPServer(idptest.WithHTTPTokenIntrospector(httpServerIntrospector)) + require.NoError(t, httpIDPSrv.StartAndWaitForReady(time.Second)) + defer func() { _ = httpIDPSrv.Shutdown(context.Background()) }() + + grpcIDPSrv := idptest.NewGRPCServer(idptest.WithGRPCTokenIntrospector(grpcServerIntrospector)) + require.NoError(t, grpcIDPSrv.StartAndWaitForReady(time.Second)) + defer func() { grpcIDPSrv.GracefulStop() }() + + const accessToken = "access-token-with-introspection-permission" + tokenProvider := idptest.NewSimpleTokenProvider(accessToken) + + logger := log.NewDisabledLogger() + jwtParser := jwt.NewParser(jwks.NewClient(http.DefaultClient, logger), logger) + require.NoError(t, jwtParser.AddTrustedIssuerURL(httpIDPSrv.URL())) + httpServerIntrospector.JWTParser = jwtParser + grpcServerIntrospector.JWTParser = jwtParser + + jwtScopeToGRPC := func(jwtScope []jwt.AccessPolicy) []*pb.AccessTokenScope { + grpcScope := make([]*pb.AccessTokenScope, len(jwtScope)) + for i, scope := range jwtScope { + grpcScope[i] = &pb.AccessTokenScope{ + TenantUuid: scope.TenantUUID, + ResourceNamespace: scope.ResourceNamespace, + RoleName: scope.Role, + ResourcePath: scope.ResourcePath, + } + } + return grpcScope + } + + jwtExpiresAtInFuture := jwtgo.NewNumericDate(time.Now().Add(time.Hour)) + jwtIssuer := httpIDPSrv.URL() + jwtSubject := uuid.NewString() + jwtID := uuid.NewString() + jwtScope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "account_viewer", + ResourcePath: "resource-" + uuid.NewString(), + }} + + opaqueToken := "opaque-token-" + uuid.NewString() + opaqueTokenScope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "admin", + ResourcePath: "resource-" + uuid.NewString(), + }} + + httpServerIntrospector.SetScopeForJWTID(jwtID, jwtScope) + httpServerIntrospector.SetResultForToken(opaqueToken, idptoken.IntrospectionResult{ + Active: true, TokenType: idptoken.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueTokenScope}}) + grpcServerIntrospector.SetScopeForJWTID(jwtID, jwtScopeToGRPC(jwtScope)) + grpcServerIntrospector.SetResultForToken(opaqueToken, &pb.IntrospectTokenResponse{ + Active: true, TokenType: idptoken.TokenTypeBearer, Scope: jwtScopeToGRPC(opaqueTokenScope)}) + + tests := []struct { + name string + introspectorOpts idptoken.IntrospectorOpts + useGRPC bool + token string + expectedResult idptoken.IntrospectionResult + checkError func(t *gotesting.T, err error) + expectedHTTPSrvCalled bool + expectedHTTPFormVals url.Values + expectedGRPCSrvCalled bool + expectedGRPCScopeFilter []*pb.IntrospectionScopeFilter + }{ + { + name: "error, token is missing", + token: "", + checkError: func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, "token is missing") + }, + }, + { + name: "error, dynamic introspection endpoint, no jwt header", + token: "opaque-token", + checkError: func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, "no JWT header found") + }, + }, + { + name: "error, dynamic introspection endpoint, cannot decode jwt header", + token: "$opaque$.$token$", + checkError: func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, "decode JWT header") + }, + }, + { + name: "error, dynamic introspection endpoint, issuer is not trusted", + token: idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: "https://untrusted-issuer.com", + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + }, + }), + checkError: func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, `issuer "https://untrusted-issuer.com" is not trusted`) + }, + }, + { + name: "error, dynamic introspection endpoint, issuer is missing in JWT header and payload", + token: idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + }, + }), + checkError: func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenNotIntrospectable) + require.ErrorContains(t, err, "no issuer found in JWT") + }, + }, + { + name: "error, dynamic introspection endpoint, introspection is not needed", + introspectorOpts: idptoken.IntrospectorOpts{ + MinJWTVersion: 3, + }, + token: idptest.MustMakeTokenStringWithHeader(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + }, + }, idptest.TestKeyID, idptest.GetTestRSAPrivateKey(), map[string]interface{}{"ver": 2}), + checkError: func(t *gotesting.T, err error) { + require.ErrorIs(t, err, idptoken.ErrTokenIntrospectionNotNeeded) + }, + }, + { + name: "ok, dynamic introspection endpoint, introspected token is expired JWT", + token: idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: httpIDPSrv.URL(), + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(-time.Hour)), + }, + }), + expectedResult: idptoken.IntrospectionResult{Active: false}, + expectedHTTPSrvCalled: true, + }, + { + name: "ok, dynamic introspection endpoint, introspected token is JWT", + token: idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: jwtIssuer, + Subject: jwtSubject, + ID: jwtID, + ExpiresAt: jwtExpiresAtInFuture, + }, + }), + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: idptoken.TokenTypeBearer, + Claims: jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: jwtIssuer, + Subject: jwtSubject, + ID: jwtID, + ExpiresAt: jwtExpiresAtInFuture, + }, + Scope: jwtScope, + }, + }, + expectedHTTPSrvCalled: true, + }, + { + name: "ok, static introspection endpoint, introspected token is opaque", + introspectorOpts: idptoken.IntrospectorOpts{ + StaticHTTPEndpoint: httpIDPSrv.URL() + idptest.TokenIntrospectionEndpointPath, + }, + token: opaqueToken, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: idptoken.TokenTypeBearer, + Claims: jwt.Claims{Scope: opaqueTokenScope}, + }, + expectedHTTPSrvCalled: true, + }, + { + name: "ok, static introspection endpoint, introspected token is opaque, filter scope by resource namespace", + introspectorOpts: idptoken.IntrospectorOpts{ + StaticHTTPEndpoint: httpIDPSrv.URL() + idptest.TokenIntrospectionEndpointPath, + ScopeFilter: []idptoken.IntrospectionScopeFilterAccessPolicy{ + {ResourceNamespace: "account-server"}, + {ResourceNamespace: "tenant-manager"}, + }, + }, + token: opaqueToken, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: idptoken.TokenTypeBearer, + Claims: jwt.Claims{Scope: opaqueTokenScope}, + }, + expectedHTTPSrvCalled: true, + expectedHTTPFormVals: url.Values{ + "token": {opaqueToken}, + "scope_filter[0].rn": {"account-server"}, + "scope_filter[1].rn": {"tenant-manager"}, + }, + }, + { + name: "ok, grpc introspection endpoint", + useGRPC: true, + introspectorOpts: idptoken.IntrospectorOpts{ + ScopeFilter: []idptoken.IntrospectionScopeFilterAccessPolicy{ + {ResourceNamespace: "account-server"}, + {ResourceNamespace: "tenant-manager"}, + }, + }, + token: opaqueToken, + expectedResult: idptoken.IntrospectionResult{ + Active: true, + TokenType: idptoken.TokenTypeBearer, + Claims: jwt.Claims{Scope: opaqueTokenScope}, + }, + expectedGRPCSrvCalled: true, + expectedGRPCScopeFilter: []*pb.IntrospectionScopeFilter{ + {ResourceNamespace: "account-server"}, + {ResourceNamespace: "tenant-manager"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *gotesting.T) { + if tt.useGRPC { + grpcClient, err := idptoken.NewGRPCClient(grpcIDPSrv.Addr(), insecure.NewCredentials()) + require.NoError(t, err) + defer func() { require.NoError(t, grpcClient.Close()) }() + tt.introspectorOpts.GRPCClient = grpcClient + } + introspector := idptoken.NewIntrospectorWithOpts(tokenProvider, tt.introspectorOpts) + require.NoError(t, introspector.AddTrustedIssuerURL(httpIDPSrv.URL())) + + httpServerIntrospector.ResetCallsInfo() + grpcServerIntrospector.ResetCallsInfo() + + result, err := introspector.IntrospectToken(context.Background(), tt.token) + if tt.checkError != nil { + tt.checkError(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedResult, result) + } + + require.Equal(t, tt.expectedHTTPSrvCalled, httpServerIntrospector.Called) + if tt.expectedHTTPSrvCalled { + require.Equal(t, tt.token, httpServerIntrospector.LastIntrospectedToken) + require.Equal(t, "Bearer "+accessToken, httpServerIntrospector.LastAuthorizationHeader) + if tt.expectedHTTPFormVals == nil { + tt.expectedHTTPFormVals = url.Values{"token": {tt.token}} + } + require.Equal(t, tt.expectedHTTPFormVals, httpServerIntrospector.LastFormValues) + } + + require.Equal(t, tt.expectedGRPCSrvCalled, grpcServerIntrospector.Called) + if tt.expectedGRPCSrvCalled { + require.Equal(t, tt.token, grpcServerIntrospector.LastRequest.Token) + require.Equal(t, tt.expectedGRPCScopeFilter, grpcServerIntrospector.LastRequest.GetScopeFilter()) + require.Equal(t, "Bearer "+accessToken, grpcServerIntrospector.LastAuthorizationMeta) + } + }) + } +} diff --git a/idptoken/pb/idp_token.pb.go b/idptoken/pb/idp_token.pb.go new file mode 100644 index 0000000..0feb757 --- /dev/null +++ b/idptoken/pb/idp_token.pb.go @@ -0,0 +1,682 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.33.0 +// protoc v5.26.1 +// source: idp_token.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type CreateTokenRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + GrantType string `protobuf:"bytes,1,opt,name=grant_type,json=grantType,proto3" json:"grant_type,omitempty"` // example: urn:ietf:params:oauth:grant-type:jwt-bearer + Assertion string `protobuf:"bytes,2,opt,name=assertion,proto3" json:"assertion,omitempty"` + TokenVersion uint32 `protobuf:"varint,3,opt,name=token_version,json=tokenVersion,proto3" json:"token_version,omitempty"` +} + +func (x *CreateTokenRequest) Reset() { + *x = CreateTokenRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_idp_token_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CreateTokenRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CreateTokenRequest) ProtoMessage() {} + +func (x *CreateTokenRequest) ProtoReflect() protoreflect.Message { + mi := &file_idp_token_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CreateTokenRequest.ProtoReflect.Descriptor instead. +func (*CreateTokenRequest) Descriptor() ([]byte, []int) { + return file_idp_token_proto_rawDescGZIP(), []int{0} +} + +func (x *CreateTokenRequest) GetGrantType() string { + if x != nil { + return x.GrantType + } + return "" +} + +func (x *CreateTokenRequest) GetAssertion() string { + if x != nil { + return x.Assertion + } + return "" +} + +func (x *CreateTokenRequest) GetTokenVersion() uint32 { + if x != nil { + return x.TokenVersion + } + return 0 +} + +type CreateTokenResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + AccessToken string `protobuf:"bytes,1,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"` + TokenType string `protobuf:"bytes,2,opt,name=token_type,json=tokenType,proto3" json:"token_type,omitempty"` + ExpiresIn int64 `protobuf:"varint,3,opt,name=expires_in,json=expiresIn,proto3" json:"expires_in,omitempty"` +} + +func (x *CreateTokenResponse) Reset() { + *x = CreateTokenResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_idp_token_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CreateTokenResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CreateTokenResponse) ProtoMessage() {} + +func (x *CreateTokenResponse) ProtoReflect() protoreflect.Message { + mi := &file_idp_token_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CreateTokenResponse.ProtoReflect.Descriptor instead. +func (*CreateTokenResponse) Descriptor() ([]byte, []int) { + return file_idp_token_proto_rawDescGZIP(), []int{1} +} + +func (x *CreateTokenResponse) GetAccessToken() string { + if x != nil { + return x.AccessToken + } + return "" +} + +func (x *CreateTokenResponse) GetTokenType() string { + if x != nil { + return x.TokenType + } + return "" +} + +func (x *CreateTokenResponse) GetExpiresIn() int64 { + if x != nil { + return x.ExpiresIn + } + return 0 +} + +type IntrospectionScopeFilter struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ResourceNamespace string `protobuf:"bytes,1,opt,name=ResourceNamespace,proto3" json:"ResourceNamespace,omitempty"` +} + +func (x *IntrospectionScopeFilter) Reset() { + *x = IntrospectionScopeFilter{} + if protoimpl.UnsafeEnabled { + mi := &file_idp_token_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *IntrospectionScopeFilter) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IntrospectionScopeFilter) ProtoMessage() {} + +func (x *IntrospectionScopeFilter) ProtoReflect() protoreflect.Message { + mi := &file_idp_token_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IntrospectionScopeFilter.ProtoReflect.Descriptor instead. +func (*IntrospectionScopeFilter) Descriptor() ([]byte, []int) { + return file_idp_token_proto_rawDescGZIP(), []int{2} +} + +func (x *IntrospectionScopeFilter) GetResourceNamespace() string { + if x != nil { + return x.ResourceNamespace + } + return "" +} + +type IntrospectTokenRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` + ScopeFilter []*IntrospectionScopeFilter `protobuf:"bytes,2,rep,name=scope_filter,json=scopeFilter,proto3" json:"scope_filter,omitempty"` +} + +func (x *IntrospectTokenRequest) Reset() { + *x = IntrospectTokenRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_idp_token_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *IntrospectTokenRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IntrospectTokenRequest) ProtoMessage() {} + +func (x *IntrospectTokenRequest) ProtoReflect() protoreflect.Message { + mi := &file_idp_token_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IntrospectTokenRequest.ProtoReflect.Descriptor instead. +func (*IntrospectTokenRequest) Descriptor() ([]byte, []int) { + return file_idp_token_proto_rawDescGZIP(), []int{3} +} + +func (x *IntrospectTokenRequest) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + +func (x *IntrospectTokenRequest) GetScopeFilter() []*IntrospectionScopeFilter { + if x != nil { + return x.ScopeFilter + } + return nil +} + +type AccessTokenScope struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + TenantUuid string `protobuf:"bytes,1,opt,name=tenant_uuid,json=tenantUuid,proto3" json:"tenant_uuid,omitempty"` + TenantIntId int64 `protobuf:"varint,2,opt,name=tenant_int_id,json=tenantIntId,proto3" json:"tenant_int_id,omitempty"` + ResourceServer string `protobuf:"bytes,3,opt,name=ResourceServer,proto3" json:"ResourceServer,omitempty"` + ResourceNamespace string `protobuf:"bytes,4,opt,name=ResourceNamespace,proto3" json:"ResourceNamespace,omitempty"` + ResourcePath string `protobuf:"bytes,5,opt,name=ResourcePath,proto3" json:"ResourcePath,omitempty"` + RoleName string `protobuf:"bytes,6,opt,name=RoleName,proto3" json:"RoleName,omitempty"` +} + +func (x *AccessTokenScope) Reset() { + *x = AccessTokenScope{} + if protoimpl.UnsafeEnabled { + mi := &file_idp_token_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AccessTokenScope) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AccessTokenScope) ProtoMessage() {} + +func (x *AccessTokenScope) ProtoReflect() protoreflect.Message { + mi := &file_idp_token_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AccessTokenScope.ProtoReflect.Descriptor instead. +func (*AccessTokenScope) Descriptor() ([]byte, []int) { + return file_idp_token_proto_rawDescGZIP(), []int{4} +} + +func (x *AccessTokenScope) GetTenantUuid() string { + if x != nil { + return x.TenantUuid + } + return "" +} + +func (x *AccessTokenScope) GetTenantIntId() int64 { + if x != nil { + return x.TenantIntId + } + return 0 +} + +func (x *AccessTokenScope) GetResourceServer() string { + if x != nil { + return x.ResourceServer + } + return "" +} + +func (x *AccessTokenScope) GetResourceNamespace() string { + if x != nil { + return x.ResourceNamespace + } + return "" +} + +func (x *AccessTokenScope) GetResourcePath() string { + if x != nil { + return x.ResourcePath + } + return "" +} + +func (x *AccessTokenScope) GetRoleName() string { + if x != nil { + return x.RoleName + } + return "" +} + +type IntrospectTokenResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Active bool `protobuf:"varint,1,opt,name=active,proto3" json:"active,omitempty"` + TokenType string `protobuf:"bytes,2,opt,name=token_type,json=tokenType,proto3" json:"token_type,omitempty"` + Exp int64 `protobuf:"varint,3,opt,name=exp,proto3" json:"exp,omitempty"` + Aud []string `protobuf:"bytes,4,rep,name=aud,proto3" json:"aud,omitempty"` + Jti string `protobuf:"bytes,5,opt,name=jti,proto3" json:"jti,omitempty"` + Iss string `protobuf:"bytes,6,opt,name=iss,proto3" json:"iss,omitempty"` + Sub string `protobuf:"bytes,7,opt,name=sub,proto3" json:"sub,omitempty"` + SubType string `protobuf:"bytes,8,opt,name=sub_type,json=subType,proto3" json:"sub_type,omitempty"` + ClientId string `protobuf:"bytes,9,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` + OwnerTenantUuid string `protobuf:"bytes,10,opt,name=owner_tenant_uuid,json=ownerTenantUuid,proto3" json:"owner_tenant_uuid,omitempty"` // API client owner tenant UUID. + Scope []*AccessTokenScope `protobuf:"bytes,11,rep,name=scope,proto3" json:"scope,omitempty"` +} + +func (x *IntrospectTokenResponse) Reset() { + *x = IntrospectTokenResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_idp_token_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *IntrospectTokenResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IntrospectTokenResponse) ProtoMessage() {} + +func (x *IntrospectTokenResponse) ProtoReflect() protoreflect.Message { + mi := &file_idp_token_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IntrospectTokenResponse.ProtoReflect.Descriptor instead. +func (*IntrospectTokenResponse) Descriptor() ([]byte, []int) { + return file_idp_token_proto_rawDescGZIP(), []int{5} +} + +func (x *IntrospectTokenResponse) GetActive() bool { + if x != nil { + return x.Active + } + return false +} + +func (x *IntrospectTokenResponse) GetTokenType() string { + if x != nil { + return x.TokenType + } + return "" +} + +func (x *IntrospectTokenResponse) GetExp() int64 { + if x != nil { + return x.Exp + } + return 0 +} + +func (x *IntrospectTokenResponse) GetAud() []string { + if x != nil { + return x.Aud + } + return nil +} + +func (x *IntrospectTokenResponse) GetJti() string { + if x != nil { + return x.Jti + } + return "" +} + +func (x *IntrospectTokenResponse) GetIss() string { + if x != nil { + return x.Iss + } + return "" +} + +func (x *IntrospectTokenResponse) GetSub() string { + if x != nil { + return x.Sub + } + return "" +} + +func (x *IntrospectTokenResponse) GetSubType() string { + if x != nil { + return x.SubType + } + return "" +} + +func (x *IntrospectTokenResponse) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *IntrospectTokenResponse) GetOwnerTenantUuid() string { + if x != nil { + return x.OwnerTenantUuid + } + return "" +} + +func (x *IntrospectTokenResponse) GetScope() []*AccessTokenScope { + if x != nil { + return x.Scope + } + return nil +} + +var File_idp_token_proto protoreflect.FileDescriptor + +var file_idp_token_proto_rawDesc = []byte{ + 0x0a, 0x0f, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x12, 0x09, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0x7c, 0x0a, 0x12, + 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x67, 0x72, 0x61, 0x6e, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x67, 0x72, 0x61, 0x6e, 0x74, 0x54, 0x79, 0x70, + 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x73, 0x73, 0x65, 0x72, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x73, 0x73, 0x65, 0x72, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x23, 0x0a, 0x0d, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x4a, 0x04, 0x08, 0x04, 0x10, 0x33, 0x22, 0x7c, 0x0a, 0x13, 0x43, 0x72, + 0x65, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, + 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x74, 0x79, + 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x54, + 0x79, 0x70, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x5f, 0x69, + 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, + 0x49, 0x6e, 0x4a, 0x04, 0x08, 0x04, 0x10, 0x33, 0x22, 0x4e, 0x0a, 0x18, 0x49, 0x6e, 0x74, 0x72, + 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x46, 0x69, + 0x6c, 0x74, 0x65, 0x72, 0x12, 0x2c, 0x0a, 0x11, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, + 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x11, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, + 0x63, 0x65, 0x4a, 0x04, 0x08, 0x02, 0x10, 0x33, 0x22, 0x7c, 0x0a, 0x16, 0x49, 0x6e, 0x74, 0x72, + 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x46, 0x0a, 0x0c, 0x73, 0x63, 0x6f, 0x70, + 0x65, 0x5f, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x23, + 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x49, 0x6e, 0x74, 0x72, 0x6f, + 0x73, 0x70, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x46, 0x69, 0x6c, + 0x74, 0x65, 0x72, 0x52, 0x0b, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, + 0x4a, 0x04, 0x08, 0x03, 0x10, 0x33, 0x22, 0xf3, 0x01, 0x0a, 0x10, 0x41, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x74, + 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x5f, 0x75, 0x75, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0a, 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x55, 0x75, 0x69, 0x64, 0x12, 0x22, 0x0a, 0x0d, + 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x5f, 0x69, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x0b, 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x49, 0x6e, 0x74, 0x49, 0x64, + 0x12, 0x26, 0x0a, 0x0e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, + 0x63, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x2c, 0x0a, 0x11, 0x52, 0x65, 0x73, 0x6f, + 0x75, 0x72, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x11, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x4e, 0x61, 0x6d, + 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, + 0x63, 0x65, 0x50, 0x61, 0x74, 0x68, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, + 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x61, 0x74, 0x68, 0x12, 0x1a, 0x0a, 0x08, 0x52, 0x6f, + 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x52, 0x6f, + 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x4a, 0x04, 0x08, 0x07, 0x10, 0x33, 0x22, 0xc7, 0x02, 0x0a, + 0x17, 0x49, 0x6e, 0x74, 0x72, 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, + 0x76, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, + 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x12, + 0x10, 0x0a, 0x03, 0x65, 0x78, 0x70, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x65, 0x78, + 0x70, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x75, 0x64, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, + 0x61, 0x75, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6a, 0x74, 0x69, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x03, 0x6a, 0x74, 0x69, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x73, 0x73, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x69, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x75, 0x62, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x75, 0x62, 0x12, 0x19, 0x0a, 0x08, 0x73, 0x75, 0x62, + 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x73, 0x75, 0x62, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, + 0x64, 0x12, 0x2a, 0x0a, 0x11, 0x6f, 0x77, 0x6e, 0x65, 0x72, 0x5f, 0x74, 0x65, 0x6e, 0x61, 0x6e, + 0x74, 0x5f, 0x75, 0x75, 0x69, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x6f, 0x77, + 0x6e, 0x65, 0x72, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x55, 0x75, 0x69, 0x64, 0x12, 0x31, 0x0a, + 0x05, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x69, + 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x52, 0x05, 0x73, 0x63, 0x6f, 0x70, 0x65, + 0x4a, 0x04, 0x08, 0x0c, 0x10, 0x65, 0x32, 0xb9, 0x01, 0x0a, 0x0f, 0x49, 0x44, 0x50, 0x54, 0x6f, + 0x6b, 0x65, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x4c, 0x0a, 0x0b, 0x43, 0x72, + 0x65, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x1d, 0x2e, 0x69, 0x64, 0x70, 0x5f, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, + 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x58, 0x0a, 0x0f, 0x49, 0x6e, 0x74, 0x72, + 0x6f, 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x21, 0x2e, 0x69, 0x64, + 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x49, 0x6e, 0x74, 0x72, 0x6f, 0x73, 0x70, 0x65, + 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, + 0x2e, 0x69, 0x64, 0x70, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x49, 0x6e, 0x74, 0x72, 0x6f, + 0x73, 0x70, 0x65, 0x63, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x42, 0x06, 0x5a, 0x04, 0x2e, 0x2f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, +} + +var ( + file_idp_token_proto_rawDescOnce sync.Once + file_idp_token_proto_rawDescData = file_idp_token_proto_rawDesc +) + +func file_idp_token_proto_rawDescGZIP() []byte { + file_idp_token_proto_rawDescOnce.Do(func() { + file_idp_token_proto_rawDescData = protoimpl.X.CompressGZIP(file_idp_token_proto_rawDescData) + }) + return file_idp_token_proto_rawDescData +} + +var file_idp_token_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_idp_token_proto_goTypes = []interface{}{ + (*CreateTokenRequest)(nil), // 0: idp_token.CreateTokenRequest + (*CreateTokenResponse)(nil), // 1: idp_token.CreateTokenResponse + (*IntrospectionScopeFilter)(nil), // 2: idp_token.IntrospectionScopeFilter + (*IntrospectTokenRequest)(nil), // 3: idp_token.IntrospectTokenRequest + (*AccessTokenScope)(nil), // 4: idp_token.AccessTokenScope + (*IntrospectTokenResponse)(nil), // 5: idp_token.IntrospectTokenResponse +} +var file_idp_token_proto_depIdxs = []int32{ + 2, // 0: idp_token.IntrospectTokenRequest.scope_filter:type_name -> idp_token.IntrospectionScopeFilter + 4, // 1: idp_token.IntrospectTokenResponse.scope:type_name -> idp_token.AccessTokenScope + 0, // 2: idp_token.IDPTokenService.CreateToken:input_type -> idp_token.CreateTokenRequest + 3, // 3: idp_token.IDPTokenService.IntrospectToken:input_type -> idp_token.IntrospectTokenRequest + 1, // 4: idp_token.IDPTokenService.CreateToken:output_type -> idp_token.CreateTokenResponse + 5, // 5: idp_token.IDPTokenService.IntrospectToken:output_type -> idp_token.IntrospectTokenResponse + 4, // [4:6] is the sub-list for method output_type + 2, // [2:4] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_idp_token_proto_init() } +func file_idp_token_proto_init() { + if File_idp_token_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_idp_token_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateTokenRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_idp_token_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateTokenResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_idp_token_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*IntrospectionScopeFilter); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_idp_token_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*IntrospectTokenRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_idp_token_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*AccessTokenScope); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_idp_token_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*IntrospectTokenResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_idp_token_proto_rawDesc, + NumEnums: 0, + NumMessages: 6, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_idp_token_proto_goTypes, + DependencyIndexes: file_idp_token_proto_depIdxs, + MessageInfos: file_idp_token_proto_msgTypes, + }.Build() + File_idp_token_proto = out.File + file_idp_token_proto_rawDesc = nil + file_idp_token_proto_goTypes = nil + file_idp_token_proto_depIdxs = nil +} diff --git a/idptoken/pb/idp_token_grpc.pb.go b/idptoken/pb/idp_token_grpc.pb.go new file mode 100644 index 0000000..a338bb6 --- /dev/null +++ b/idptoken/pb/idp_token_grpc.pb.go @@ -0,0 +1,154 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.3.0 +// - protoc v5.26.1 +// source: idp_token.proto + +package pb + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +const ( + IDPTokenService_CreateToken_FullMethodName = "/idp_token.IDPTokenService/CreateToken" + IDPTokenService_IntrospectToken_FullMethodName = "/idp_token.IDPTokenService/IntrospectToken" +) + +// IDPTokenServiceClient is the client API for IDPTokenService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type IDPTokenServiceClient interface { + // CreateToken creates a new token based on the provided assertion. + // Now only JWT_BEARER (urn:ietf:params:oauth:grant-type:jwt-bearer) grant type is supported. + CreateToken(ctx context.Context, in *CreateTokenRequest, opts ...grpc.CallOption) (*CreateTokenResponse, error) + // IntrospectToken returns information about the token including its scopes. + // The token is considered active if it 1) is not expired; 2) is not revoked; 3) has has the valid signature. + IntrospectToken(ctx context.Context, in *IntrospectTokenRequest, opts ...grpc.CallOption) (*IntrospectTokenResponse, error) +} + +type iDPTokenServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewIDPTokenServiceClient(cc grpc.ClientConnInterface) IDPTokenServiceClient { + return &iDPTokenServiceClient{cc} +} + +func (c *iDPTokenServiceClient) CreateToken(ctx context.Context, in *CreateTokenRequest, opts ...grpc.CallOption) (*CreateTokenResponse, error) { + out := new(CreateTokenResponse) + err := c.cc.Invoke(ctx, IDPTokenService_CreateToken_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *iDPTokenServiceClient) IntrospectToken(ctx context.Context, in *IntrospectTokenRequest, opts ...grpc.CallOption) (*IntrospectTokenResponse, error) { + out := new(IntrospectTokenResponse) + err := c.cc.Invoke(ctx, IDPTokenService_IntrospectToken_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// IDPTokenServiceServer is the server API for IDPTokenService service. +// All implementations must embed UnimplementedIDPTokenServiceServer +// for forward compatibility +type IDPTokenServiceServer interface { + // CreateToken creates a new token based on the provided assertion. + // Now only JWT_BEARER (urn:ietf:params:oauth:grant-type:jwt-bearer) grant type is supported. + CreateToken(context.Context, *CreateTokenRequest) (*CreateTokenResponse, error) + // IntrospectToken returns information about the token including its scopes. + // The token is considered active if it 1) is not expired; 2) is not revoked; 3) has has the valid signature. + IntrospectToken(context.Context, *IntrospectTokenRequest) (*IntrospectTokenResponse, error) + mustEmbedUnimplementedIDPTokenServiceServer() +} + +// UnimplementedIDPTokenServiceServer must be embedded to have forward compatible implementations. +type UnimplementedIDPTokenServiceServer struct { +} + +func (UnimplementedIDPTokenServiceServer) CreateToken(context.Context, *CreateTokenRequest) (*CreateTokenResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method CreateToken not implemented") +} +func (UnimplementedIDPTokenServiceServer) IntrospectToken(context.Context, *IntrospectTokenRequest) (*IntrospectTokenResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method IntrospectToken not implemented") +} +func (UnimplementedIDPTokenServiceServer) mustEmbedUnimplementedIDPTokenServiceServer() {} + +// UnsafeIDPTokenServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to IDPTokenServiceServer will +// result in compilation errors. +type UnsafeIDPTokenServiceServer interface { + mustEmbedUnimplementedIDPTokenServiceServer() +} + +func RegisterIDPTokenServiceServer(s grpc.ServiceRegistrar, srv IDPTokenServiceServer) { + s.RegisterService(&IDPTokenService_ServiceDesc, srv) +} + +func _IDPTokenService_CreateToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CreateTokenRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(IDPTokenServiceServer).CreateToken(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: IDPTokenService_CreateToken_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(IDPTokenServiceServer).CreateToken(ctx, req.(*CreateTokenRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _IDPTokenService_IntrospectToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(IntrospectTokenRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(IDPTokenServiceServer).IntrospectToken(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: IDPTokenService_IntrospectToken_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(IDPTokenServiceServer).IntrospectToken(ctx, req.(*IntrospectTokenRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// IDPTokenService_ServiceDesc is the grpc.ServiceDesc for IDPTokenService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var IDPTokenService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "idp_token.IDPTokenService", + HandlerType: (*IDPTokenServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "CreateToken", + Handler: _IDPTokenService_CreateToken_Handler, + }, + { + MethodName: "IntrospectToken", + Handler: _IDPTokenService_IntrospectToken_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "idp_token.proto", +} diff --git a/idptoken/provider.go b/idptoken/provider.go new file mode 100644 index 0000000..80f44a5 --- /dev/null +++ b/idptoken/provider.go @@ -0,0 +1,692 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken + +import ( + "context" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "math/big" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/acronis/go-appkit/log" + "golang.org/x/sync/singleflight" + + "github.com/acronis/go-authkit/internal/idputil" + "github.com/acronis/go-authkit/internal/metrics" +) + +const ( + defaultMinRefreshPeriod = time.Second * 10 + defaultExpirationOffset = time.Minute * 30 + expiryDeltaMaxOffset = 5 + wellKnownPath = "/.well-known/openid-configuration" +) + +var ( + // ErrSourceNotRegistered is returned if GetToken is requested for the unknown Source + ErrSourceNotRegistered = errors.New("cannot issue token for unknown source") +) + +// UnexpectedIDPResponseError is an error representing an unexpected response +type UnexpectedIDPResponseError struct { + HTTPCode int + IssueURL string +} + +func (e *UnexpectedIDPResponseError) Error() string { + return fmt.Sprintf(`%s responded with unexpected code %d`, e.IssueURL, e.HTTPCode) +} + +// TokenData represents API-related token information +type TokenData struct { + Data string + ClientID string + issueURL string + Scope []string + Expires time.Time +} + +// Source serves to provide auth source information to MultiSourceProvider and Provider +type Source struct { + URL string + ClientID string + ClientSecret string +} + +var zeroTime = time.Time{} + +// TokenDetails represents the data to be stored in TokenCache +type TokenDetails struct { + token TokenData + requestedScope []string + sourceURL string + issued time.Time + nextRefresh time.Time + invalidation time.Time +} + +// TokenCache is a cache entry used to store TokenDetails based on a string key +type TokenCache interface { + // Get returns a value from the cache by key. + Get(key string) *TokenDetails + + // Put sets a new value to the cache by key. + Put(key string, val *TokenDetails) + + // Delete removes a value from the cache by key. + Delete(key string) + + // ClearAll removes all values from the cache. + ClearAll() + + // Keys returns all keys from the cache. + Keys() []string +} + +type InMemoryTokenCache struct { + mu sync.RWMutex + items map[string]*TokenDetails +} + +func NewInMemoryTokenCache() *InMemoryTokenCache { + return &InMemoryTokenCache{items: make(map[string]*TokenDetails)} +} + +func (c *InMemoryTokenCache) Keys() []string { + c.mu.RLock() + defer c.mu.RUnlock() + result := make([]string, 0, len(c.items)) + for k := range c.items { + result = append(result, k) + } + return result +} + +func (c *InMemoryTokenCache) Get(key string) *TokenDetails { + c.mu.RLock() + defer c.mu.RUnlock() + item, found := c.items[key] + if !found { + return nil + } + return item +} + +func (c *InMemoryTokenCache) Put(key string, val *TokenDetails) { + c.mu.Lock() + defer c.mu.Unlock() + c.items[key] = val +} + +func (c *InMemoryTokenCache) Delete(key string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.items, key) +} + +func (c *InMemoryTokenCache) ClearAll() { + c.mu.Lock() + defer c.mu.Unlock() + c.items = make(map[string]*TokenDetails) +} + +// MultiSourceProvider is a caching token provider for multiple datacenters and clients +type MultiSourceProvider struct { + tokenIssuers map[string]*oauth2Issuer + + rescheduleSignal chan struct{} + minRefreshPeriod time.Duration + httpClient *http.Client + logger log.FieldLogger + promMetrics *metrics.PrometheusMetrics + + cache TokenCache + customHeaders map[string]string + sfGroup singleflight.Group + + nextRefreshMu sync.RWMutex + nextRefresh time.Time +} + +// NewMultiSourceProviderWithOpts returns a new instance of MultiSourceProvider with custom settings +func NewMultiSourceProviderWithOpts( + httpClient *http.Client, opts ProviderOpts, sources ...Source, +) *MultiSourceProvider { + p := MultiSourceProvider{} + + if opts.Logger == nil { + opts.Logger = log.NewDisabledLogger() + } + + if opts.MinRefreshPeriod == 0 { + opts.MinRefreshPeriod = defaultMinRefreshPeriod + } + + p.init(httpClient, opts, sources...) + return &p +} + +// NewMultiSourceProvider returns a new instance of MultiSourceProvider with default settings +func NewMultiSourceProvider(httpClient *http.Client) *MultiSourceProvider { + return NewMultiSourceProviderWithOpts(httpClient, ProviderOpts{}) +} + +// RegisterSource allows registering a new Source into MultiSourceProvider +func (p *MultiSourceProvider) RegisterSource(source Source) { + key := keyForIssuer(source.ClientID, source.URL) + if iss, found := p.tokenIssuers[key]; found { + if iss.clientSecret != source.ClientSecret { + iss.clientSecret = source.ClientSecret + p.cache.ClearAll() + } + } + newIssuer := p.newOAuth2Issuer(source.URL, source.ClientID, source.ClientSecret) + p.tokenIssuers[keyForIssuer(source.ClientID, source.URL)] = newIssuer +} + +// GetToken returns raw token for `clientID`, `sourceURL` and `scope` +func (p *MultiSourceProvider) GetToken( + ctx context.Context, clientID, sourceURL string, scope ...string, +) (string, error) { + return p.GetTokenWithHeaders(ctx, clientID, sourceURL, nil, scope...) +} + +// GetTokenWithHeaders returns raw token for `clientID`, `sourceURL` and `scope` while using `headers` +func (p *MultiSourceProvider) GetTokenWithHeaders( + ctx context.Context, clientID, sourceURL string, headers map[string]string, scope ...string, +) (string, error) { + return p.ensureToken(ctx, clientID, sourceURL, headers, scope) +} + +// Invalidate fully invalidates all tokens cache +func (p *MultiSourceProvider) Invalidate() { + p.cache.ClearAll() + p.setNextRefreshSafe(zeroTime) +} + +// RefreshTokensPeriodically starts a goroutine which refreshes tokens +func (p *MultiSourceProvider) RefreshTokensPeriodically(ctx context.Context) { + p.refreshLoop(ctx) +} + +func (p *MultiSourceProvider) issueToken( + ctx context.Context, clientID, sourceURL string, customHeaders map[string]string, scope []string, +) (TokenData, error) { + issuer, found := p.tokenIssuers[keyForIssuer(clientID, sourceURL)] + + if !found { + return TokenData{}, ErrSourceNotRegistered + } + + headers := make(map[string]string) + for k := range p.customHeaders { + headers[k] = p.customHeaders[k] + } + for k := range customHeaders { + headers[k] = customHeaders[k] + } + + _, errEns, _ := p.sfGroup.Do(keyForIssuer(clientID, sourceURL), func() (interface{}, error) { + return nil, issuer.EnsureIssuerURL(ctx, headers) + }) + + if errEns != nil { + p.logger.Error(fmt.Sprintf("(%s, %s): ensure issuer URL", sourceURL, clientID), log.Error(errEns)) + return TokenData{}, errEns + } + + sortedScope := uniqAndSort(scope) + key := keyForCache(clientID, issuer.loadIssuerURL(), sortedScope) + + token, err, _ := p.sfGroup.Do(key, func() (interface{}, error) { + result, issErr := issuer.IssueToken(ctx, headers, sortedScope) + p.cacheToken(result, sourceURL) + return result, issErr + }) + + if err != nil { + p.logger.Error(fmt.Sprintf("(%s, %s): issuing token", issuer.loadIssuerURL(), clientID), log.Error(err)) + return TokenData{}, err + } + + return token.(TokenData), nil +} + +func (p *MultiSourceProvider) ensureToken( + ctx context.Context, clientID, sourceURL string, customHeaders map[string]string, scope []string, +) (string, error) { + token, err := p.getCachedOrInvalidate(clientID, sourceURL, scope) + if err == nil { + return token.Data, nil + } + p.logger.Infof("(%s, %s): could not get token from cache: %v", sourceURL, clientID, err.Error()) + + token, err = p.issueToken(ctx, clientID, sourceURL, customHeaders, scope) + + if err != nil { + return "", err + } + return token.Data, nil +} + +func (p *MultiSourceProvider) cacheToken(token TokenData, sourceURL string) { + issued := time.Now().UTC() + randInt, err := rand.Int(rand.Reader, big.NewInt(expiryDeltaMaxOffset)) + if err != nil { + p.logger.Error("rand init error", log.Error(err)) + return + } + deltaMinutes := time.Minute * time.Duration(randInt.Int64()) + realExpiration := token.Expires.Sub(issued) + refreshDuration := token.Expires.Sub(issued) - defaultExpirationOffset - deltaMinutes + if realExpiration < defaultExpirationOffset { + refreshDuration = realExpiration / 5 + } + + nextRefresh := issued.Add(refreshDuration) + invalidation := issued.Add(realExpiration) + details := &TokenDetails{ + token: token, + issued: issued, + nextRefresh: nextRefresh, + invalidation: invalidation, + sourceURL: sourceURL, + } + + key := keyForCache(token.ClientID, token.issueURL, uniqAndSort(token.Scope)) + p.cache.Put(key, details) + pNextRefresh := p.getNextRefreshSafe() + if pNextRefresh == zeroTime || nextRefresh.UnixNano() <= pNextRefresh.UnixNano() { + p.setNextRefreshSafe(nextRefresh) + select { + case p.rescheduleSignal <- struct{}{}: + default: + } + } +} + +func (p *MultiSourceProvider) getCachedOrInvalidate(clientID, sourceURL string, scope []string) (TokenData, error) { + now := time.Now().UnixNano() + issuer, found := p.tokenIssuers[keyForIssuer(clientID, sourceURL)] + if !found { + return TokenData{}, fmt.Errorf("(%s, %s): not registered", sourceURL, clientID) + } + if issuer.loadIssuerURL() == "" { + return TokenData{}, fmt.Errorf("(%s, %s): issuer URL not acquired", sourceURL, clientID) + } + + key := keyForCache(clientID, issuer.loadIssuerURL(), uniqAndSort(scope)) + details := p.cache.Get(key) + if details == nil { + return TokenData{}, errors.New("token not found in cache") + } + if details.token.Expires.UnixNano() < now { + p.cache.Delete(key) + return TokenData{}, errors.New("token is expired") + } + if details.invalidation.UnixNano() < now { + p.cache.Delete(key) + return TokenData{}, errors.New("token needs to be refreshed") + } + if details.issued.UnixNano() > now { + p.cache.Delete(key) + return TokenData{}, errors.New("token's issued time is invalid") + } + return details.token, nil +} + +func (p *MultiSourceProvider) setNextRefreshSafe(nextRefresh time.Time) { + p.nextRefreshMu.Lock() + p.nextRefresh = nextRefresh + p.nextRefreshMu.Unlock() +} + +func (p *MultiSourceProvider) getNextRefreshSafe() time.Time { + p.nextRefreshMu.RLock() + nextRefresh := p.nextRefresh + p.nextRefreshMu.RUnlock() + return nextRefresh +} + +func (p *MultiSourceProvider) refreshTokens(ctx context.Context) { + now := time.Now().UTC() + + resultMap := make(map[*TokenDetails]struct{}) + nextRefresh := zeroTime + for _, key := range p.cache.Keys() { + details := p.cache.Get(key) + if details == nil { + continue + } + if details.nextRefresh.UnixNano() <= now.UnixNano() { + resultMap[details] = struct{}{} + continue + } + if nextRefresh == zeroTime { + nextRefresh = details.nextRefresh + } + if details.nextRefresh.UnixNano() <= nextRefresh.UnixNano() { + nextRefresh = details.nextRefresh + } + } + p.setNextRefreshSafe(nextRefresh) + toRefresh := make([]*TokenDetails, 0, len(resultMap)) + for token := range resultMap { + toRefresh = append(toRefresh, token) + } + + for _, details := range toRefresh { + _, err := p.issueToken(ctx, details.token.ClientID, details.sourceURL, nil, details.requestedScope) + if err != nil { + p.setNextRefreshSafe(now) + p.logger.Error( + fmt.Sprintf("(%s, %s): refresh error", details.sourceURL, details.token.ClientID), log.Error(err), + ) + } + } +} + +func (p *MultiSourceProvider) refreshLoop(ctx context.Context) { + t := time.NewTimer(time.Hour) + if !t.Stop() { + <-t.C + } + stopped := true + lastRefresh := time.Now().UTC() + currentRefresh := zeroTime + scheduleNext := func() { + nextRefresh := p.getNextRefreshSafe() + + currentRefresh = nextRefresh + if nextRefresh == zeroTime { + stopped = true + return + } + + now := time.Now().UTC() + next := nextRefresh.Sub(now) + if nextRefresh.Sub(lastRefresh) < p.minRefreshPeriod { + next = lastRefresh.Add(p.minRefreshPeriod).Sub(now) + } + + stopped = false + t.Reset(next) + } + scheduleNext() + for { + select { + case <-t.C: + lastRefresh = time.Now().UTC() + p.refreshTokens(ctx) + scheduleNext() + case <-p.rescheduleSignal: + nextRefresh := p.getNextRefreshSafe() + + if currentRefresh != nextRefresh { + if !stopped && !t.Stop() { + <-t.C + } + + if stopped { + // Token was issued a moment ago. + lastRefresh = time.Now().UTC() + } + + scheduleNext() + } + case <-ctx.Done(): + if !stopped && !t.Stop() { + <-t.C + } + return + } + } +} + +func (p *MultiSourceProvider) init(httpClient *http.Client, opts ProviderOpts, sources ...Source) { + if httpClient == nil { + panic("httpClient is mandatory") + } + p.cache = opts.CustomCacheInstance + if p.cache == nil { + p.cache = NewInMemoryTokenCache() + } + p.rescheduleSignal = make(chan struct{}, 1) + p.nextRefresh = zeroTime + p.minRefreshPeriod = opts.MinRefreshPeriod + p.logger = opts.Logger + p.httpClient = httpClient + p.tokenIssuers = make(map[string]*oauth2Issuer) + p.promMetrics = metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "token_provider") + p.customHeaders = opts.CustomHeaders + + for _, source := range sources { + p.RegisterSource(source) + } +} + +// Provider is a caching token provider for a single credentials set +type Provider struct { + provider *MultiSourceProvider + source Source +} + +// NewProvider returns a new instance of Provider +func NewProvider(httpClient *http.Client, source Source) *Provider { + return NewProviderWithOpts(httpClient, ProviderOpts{}, source) +} + +// NewProviderWithOpts returns a new instance of Provider with custom options +func NewProviderWithOpts(httpClient *http.Client, opts ProviderOpts, source Source) *Provider { + mp := Provider{ + source: source, + provider: NewMultiSourceProviderWithOpts(httpClient, opts, source), + } + return &mp +} + +// RefreshTokensPeriodically starts a goroutine which refreshes tokens +func (mp *Provider) RefreshTokensPeriodically(ctx context.Context) { + mp.provider.RefreshTokensPeriodically(ctx) +} + +// GetToken returns raw token for `scope` +func (mp *Provider) GetToken( + ctx context.Context, scope ...string, +) (string, error) { + return mp.provider.GetToken(ctx, mp.source.ClientID, mp.source.URL, scope...) +} + +// GetTokenWithHeaders returns raw token for `scope` while using `headers` +func (mp *Provider) GetTokenWithHeaders( + ctx context.Context, headers map[string]string, scope ...string, +) (string, error) { + return mp.provider.GetTokenWithHeaders(ctx, mp.source.ClientID, mp.source.URL, headers, scope...) +} + +func (mp *Provider) Invalidate() { + mp.provider.Invalidate() +} + +type oauth2Issuer struct { + baseURL string + clientID string + clientSecret string + httpClient *http.Client + logger log.FieldLogger + issuerURL atomic.Value + promMetrics *metrics.PrometheusMetrics +} + +func (p *MultiSourceProvider) newOAuth2Issuer(baseURL, clientID, clientSecret string) *oauth2Issuer { + return &oauth2Issuer{ + baseURL: baseURL, + clientID: clientID, + clientSecret: clientSecret, + httpClient: p.httpClient, + logger: p.logger, + promMetrics: p.promMetrics, + } +} + +func (ti *oauth2Issuer) loadIssuerURL() string { + if v := ti.issuerURL.Load(); v != nil { + return v.(string) + } + return "" +} + +func (ti *oauth2Issuer) EnsureIssuerURL(ctx context.Context, customHeaders map[string]string) error { + if ti.loadIssuerURL() != "" { + return nil + } + + openIDCfgURL := strings.TrimSuffix(ti.baseURL, "/") + wellKnownPath + openIDCfg, err := idputil.GetOpenIDConfiguration( + ctx, ti.httpClient, openIDCfgURL, customHeaders, ti.logger, ti.promMetrics) + if err != nil { + return fmt.Errorf("(%s, %s): get OpenID configuration: %w", ti.baseURL, ti.clientID, err) + } + + if _, err = url.ParseRequestURI(openIDCfg.TokenURL); err != nil { + return fmt.Errorf("(%s, %s): issuer have returned a non-valid URL %q: %w", + ti.baseURL, ti.clientID, openIDCfg.TokenURL, err) + } + ti.issuerURL.Store(openIDCfg.TokenURL) + return nil +} + +func (ti *oauth2Issuer) IssueToken( + ctx context.Context, customHeaders map[string]string, scope []string, +) (TokenData, error) { + issuerURL := ti.loadIssuerURL() + if issuerURL == "" { + panic("must first ensure issuerURL") + } + values := url.Values{} + values.Add("grant_type", "client_credentials") + scopeStr := strings.Join(scope, " ") + if scopeStr != "" { + values.Add("scope", scopeStr) + } + req, reqErr := http.NewRequest(http.MethodPost, issuerURL, strings.NewReader(values.Encode())) + if reqErr != nil { + return TokenData{}, reqErr + } + req = req.WithContext(ctx) + req.SetBasicAuth(ti.clientID, ti.clientSecret) + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + for key := range customHeaders { + req.Header.Add(key, customHeaders[key]) + } + start := time.Now() + resp, err := ti.httpClient.Do(req) + elapsed := time.Since(start) + + if err != nil { + ti.promMetrics.ObserveHTTPClientRequest(http.MethodPost, issuerURL, 0, elapsed, metrics.HTTPRequestErrorDo) + return TokenData{}, fmt.Errorf("do http request: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + ti.logger.Error( + fmt.Sprintf("(%s, %s): closing body", ti.loadIssuerURL(), ti.clientID), log.Error(err), + ) + } + }() + + tokenResponse := tokenResponseBody{} + if err = json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { + ti.promMetrics.ObserveHTTPClientRequest( + http.MethodPost, issuerURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorDecodeBody) + return TokenData{}, fmt.Errorf( + "(%s, %s): read and unmarshal IDP response: %w", ti.loadIssuerURL(), ti.clientID, err, + ) + } + + if resp.StatusCode != http.StatusOK { + ti.promMetrics.ObserveHTTPClientRequest( + http.MethodPost, issuerURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorUnexpectedStatusCode) + return TokenData{}, &UnexpectedIDPResponseError{HTTPCode: resp.StatusCode, IssueURL: ti.loadIssuerURL()} + } + + ti.promMetrics.ObserveHTTPClientRequest(http.MethodPost, issuerURL, resp.StatusCode, elapsed, "") + expires := time.Now().Add(time.Second * time.Duration(tokenResponse.ExpiresIn)) + ti.logger.Infof("(%s, %s): issued token, expires on %s", ti.loadIssuerURL(), ti.clientID, expires.UTC()) + return TokenData{ + Data: tokenResponse.AccessToken, + Scope: scope, + Expires: expires, + issueURL: ti.loadIssuerURL(), + ClientID: ti.clientID, + }, nil +} + +// ProviderOpts represents options for creating a new MultiSourceProvider +type ProviderOpts struct { + // Logger is a logger for MultiSourceProvider. + Logger log.FieldLogger + + // MinRefreshPeriod is a minimal possible refresh interval for MultiSourceProvider's token cache. + MinRefreshPeriod time.Duration + + // CustomHeaders is a map of custom headers to be used in all HTTP requests. + CustomHeaders map[string]string + + // CustomCacheInstance is a custom token cache instance to be used in MultiSourceProvider. + CustomCacheInstance TokenCache + + // PrometheusLibInstanceLabel is a label for Prometheus metrics. + // It allows distinguishing metrics from different instances of the same service. + PrometheusLibInstanceLabel string +} + +func uniqAndSort(s []string) []string { + uniq := make(map[string]struct{}) + for ix := range s { + uniq[s[ix]] = struct{}{} + } + result := make([]string, 0, len(uniq)) + for k := range uniq { + result = append(result, k) + } + sort.Strings(result) + return result +} + +func keyForCache(clientID, sourceURL string, scope []string) string { + return clientID + ":" + sourceURL + ":" + strings.Join(scope, ",") +} + +func keyForIssuer(clientID, sourceURL string) string { + return sourceURL + ":" + clientID +} + +type tokenResponseBody struct { + AccessToken string `json:"access_token"` + Scope string `json:"scope,omitempty"` + // not empty if token scope is different + // from the requested scope. Is equal to + // serialized token scope claim. Returned + // explicitly so client can know token + // scope w/o token parsing. Useful for + // middleware token response processing + ExpiresIn int `json:"expires_in"` + + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} diff --git a/idptoken/provider_test.go b/idptoken/provider_test.go new file mode 100644 index 0000000..2f1a572 --- /dev/null +++ b/idptoken/provider_test.go @@ -0,0 +1,464 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idptoken_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + "time" + + "github.com/acronis/go-appkit/httpclient" + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-appkit/testutil" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/internal/metrics" + "github.com/acronis/go-authkit/jwt" +) + +const ( + expectedUserAgent = "Token MultiSourceProvider/1.0" + expectedXRequestID = "test" + testClientID = "89cadd1f-8649-4531-8b1d-a25de5aa3cd6" + defaultTestTokenExpirationTime = 2 +) + +type tTokenResponseBody struct { + AccessToken string `json:"access_token"` + Scope string `json:"scope,omitempty"` + ExpiresIn int `json:"expires_in"` + Error string `json:"error"` +} + +type tFailingIDPTokenHandler struct{} + +func (h *tFailingIDPTokenHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusInternalServerError) + response := tTokenResponseBody{ + Error: "server_error", + } + encoder := json.NewEncoder(rw) + err := encoder.Encode(response) + if err != nil { + http.Error(rw, fmt.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } +} + +type tHeaderCheckingIDPTokenHandler struct { + t *testing.T +} + +func (h *tHeaderCheckingIDPTokenHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + require.Equal(h.t, expectedUserAgent, r.Header.Get("User-Agent")) + require.Equal(h.t, expectedXRequestID, r.Header.Get("X-Request-ID")) + rw.WriteHeader(http.StatusOK) + response := tTokenResponseBody{ + AccessToken: "success", + ExpiresIn: defaultTestTokenExpirationTime, + Scope: "tenants:viewer", + } + encoder := json.NewEncoder(rw) + err := encoder.Encode(response) + if err != nil { + http.Error(rw, fmt.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } +} + +func TestProviderWithCache(t *testing.T) { + tr, _ := httpclient.NewRetryableRoundTripperWithOpts( + http.DefaultTransport, httpclient.RetryableRoundTripperOpts{MaxRetryAttempts: 3}, + ) + httpClient := &http.Client{Transport: tr} + logger := log.NewDisabledLogger() + + t.Run("custom headers", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPTokenHandler(&tHeaderCheckingIDPTokenHandler{t}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + CustomHeaders: map[string]string{"User-Agent": expectedUserAgent}, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + _, tokenErr := provider.GetTokenWithHeaders( + context.Background(), testClientID, server.URL(), + map[string]string{"X-Request-ID": expectedXRequestID}, "tenants:read", + ) + require.NoError(t, tokenErr) + }) + + t.Run("get token", func(t *testing.T) { + const tokenTTL = 2 * time.Second + + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: tokenTTL}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + cachedToken, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + + newToken, newTokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, newTokenErr) + require.Equal(t, cachedToken, newToken, "token was not cached") + time.Sleep(tokenTTL * 2) + + reissuedToken, reissuedTokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, reissuedTokenErr) + require.Greater(t, reissuedToken, cachedToken, "token was not re-issued") + }) + + t.Run("automatic refresh", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + + tokenOld, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + time.Sleep(3 * time.Second) + token, refreshErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, refreshErr) + require.Greater(t, token, tokenOld, "token should have already been refreshed") + }) + + t.Run("invalidate", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 10 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 10 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + + tokenOld, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + provider.Invalidate() + time.Sleep(1 * time.Second) + token, refreshErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, refreshErr) + require.Greater(t, token, tokenOld, "token should have already been refreshed") + }) + + t.Run("failing idp endpoint", func(t *testing.T) { + server := idptest.NewHTTPServer(idptest.WithHTTPTokenHandler(&tFailingIDPTokenHandler{})) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL(), + }, + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL() + "/weird", + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + _, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.Error(t, tokenErr) + labels := prometheus.Labels{ + metrics.HTTPClientRequestLabelMethod: http.MethodPost, + metrics.HTTPClientRequestLabelURL: server.URL() + idptest.TokenEndpointPath, + metrics.HTTPClientRequestLabelStatusCode: "500", + metrics.HTTPClientRequestLabelError: "unexpected_status_code", + } + promMetrics := metrics.GetPrometheusMetrics("", "token_provider") + hist := promMetrics.HTTPClientRequestDuration.With(labels).(prometheus.Histogram) + testutil.AssertSamplesCountInHistogram(t, hist, 1) + }) + + t.Run("metrics", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, + ClientSecret: "DAGztV5L2hMZyECzer6SXS", + URL: server.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + _, tokenErr := provider.GetToken(context.Background(), testClientID, server.URL(), "tenants:read") + require.NoError(t, tokenErr) + labels := prometheus.Labels{ + metrics.HTTPClientRequestLabelMethod: http.MethodPost, + metrics.HTTPClientRequestLabelURL: server.URL() + idptest.TokenEndpointPath, + metrics.HTTPClientRequestLabelStatusCode: "200", + metrics.HTTPClientRequestLabelError: "", + } + promMetrics := metrics.GetPrometheusMetrics("", "token_provider") + hist := promMetrics.HTTPClientRequestDuration.With(labels).(prometheus.Histogram) + testutil.AssertSamplesCountInHistogram(t, hist, 1) + }) + + t.Run("multiple sources", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + server2 := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + idptest.WithHTTPAddress(":8082"), + ) + require.NoError(t, server2.StartAndWaitForReady(time.Second)) + defer func() { _ = server2.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXS", URL: server.URL(), + }, + { + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXs", URL: server2.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + go provider.RefreshTokensPeriodically(context.Background()) + _, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + + _, tokenErr = provider.GetToken( + context.Background(), testClientID, server2.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + }) + + t.Run("multiple sources", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + server2 := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + idptest.WithHTTPAddress(":8082"), + ) + require.NoError(t, server2.StartAndWaitForReady(time.Second)) + defer func() { _ = server2.Shutdown(context.Background()) }() + + credentials := []idptoken.Source{ + { + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXS", URL: server.URL(), + }, + { + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXs", URL: server2.URL(), + }, + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials[0]) + go provider.RefreshTokensPeriodically(context.Background()) + provider.RegisterSource(credentials[1]) + _, tokenErr := provider.GetToken( + context.Background(), testClientID, server2.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + }) + + t.Run("single source provider", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := idptoken.Source{ + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXS", URL: server.URL(), + } + opts := idptoken.ProviderOpts{ + Logger: logger, + MinRefreshPeriod: 1 * time.Second, + } + provider := idptoken.NewProviderWithOpts(httpClient, opts, credentials) + go provider.RefreshTokensPeriodically(context.Background()) + _, tokenErr := provider.GetToken(context.Background(), "tenants:read") + require.NoError(t, tokenErr) + }) + + t.Run("start with no sources and register later", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := idptoken.Source{ + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXS", URL: server.URL(), + } + provider := idptoken.NewMultiSourceProvider(httpClient) + go provider.RefreshTokensPeriodically(context.Background()) + provider.RegisterSource(credentials) + _, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + }) + + t.Run("register source twice", func(t *testing.T) { + server := idptest.NewHTTPServer( + idptest.WithHTTPClaimsProvider(&claimsProviderWithExpiration{ExpTime: 2 * time.Second}), + ) + require.NoError(t, server.StartAndWaitForReady(time.Second)) + defer func() { _ = server.Shutdown(context.Background()) }() + + credentials := idptoken.Source{ + ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXS", URL: server.URL(), + } + tokenCache := idptoken.NewInMemoryTokenCache() + provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, idptoken.ProviderOpts{CustomCacheInstance: tokenCache}) + go provider.RefreshTokensPeriodically(context.Background()) + provider.RegisterSource(credentials) + credentials.ClientSecret = "newsecret" + provider.RegisterSource(credentials) + _, tokenErr := provider.GetToken( + context.Background(), testClientID, server.URL(), "tenants:read", + ) + require.NoError(t, tokenErr) + provider.RegisterSource(credentials) + require.Equal(t, 1, len(tokenCache.Keys()), "updating with same secret does not reset the cache") + credentials.ClientSecret = "evennewersecret" + provider.RegisterSource(credentials) + require.Equal(t, 0, len(tokenCache.Keys()), "updating with a new secret does reset the cache") + }) +} + +type claimsProviderWithExpiration struct { + ExpTime time.Duration +} + +func (d *claimsProviderWithExpiration) Provide(_ *http.Request) jwt.Claims { + claims := jwt.Claims{ + // nolint:staticcheck // StandardClaims are used here for test purposes + RegisteredClaims: jwtgo.RegisteredClaims{ + ID: uuid.NewString(), + IssuedAt: jwtgo.NewNumericDate(time.Now().UTC()), + }, + Scope: []jwt.AccessPolicy{ + { + TenantID: "1", + TenantUUID: uuid.NewString(), + Role: "tenant:viewer", + }, + }, + Version: 1, + UserID: "1", + } + + if d.ExpTime <= 0 { + d.ExpTime = 24 * time.Hour + } + claims.ExpiresAt = jwtgo.NewNumericDate(time.Now().UTC().Add(d.ExpTime)) + + return claims +} diff --git a/internal/idputil/doc.go b/internal/idputil/doc.go new file mode 100644 index 0000000..fc03f59 --- /dev/null +++ b/internal/idputil/doc.go @@ -0,0 +1,9 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package idputil provides utilities for working with identity providers. +// It's used in the internal code and not exposed to the public API. +package idputil diff --git a/internal/idputil/openid_configuration.go b/internal/idputil/openid_configuration.go new file mode 100644 index 0000000..370bb08 --- /dev/null +++ b/internal/idputil/openid_configuration.go @@ -0,0 +1,72 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idputil + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/acronis/go-appkit/log" + + "github.com/acronis/go-authkit/internal/metrics" +) + +type OpenIDConfiguration struct { + TokenURL string `json:"token_endpoint"` + IntrospectionEndpoint string `json:"introspection_endpoint"` + JWKSURI string `json:"jwks_uri"` +} + +func GetOpenIDConfiguration( + ctx context.Context, + httpClient *http.Client, + targetURL string, + additionalHeaders map[string]string, + logger log.FieldLogger, + promMetrics *metrics.PrometheusMetrics, +) (OpenIDConfiguration, error) { + req, err := http.NewRequest(http.MethodGet, targetURL, http.NoBody) + if err != nil { + return OpenIDConfiguration{}, fmt.Errorf("new request: %w", err) + } + for key, val := range additionalHeaders { + req.Header.Set(key, val) + } + + startTime := time.Now() + resp, err := httpClient.Do(req.WithContext(ctx)) + elapsed := time.Since(startTime) + if err != nil { + promMetrics.ObserveHTTPClientRequest(http.MethodGet, targetURL, 0, elapsed, metrics.HTTPRequestErrorDo) + return OpenIDConfiguration{}, fmt.Errorf("do request: %w", err) + } + defer func() { + if closeBodyErr := resp.Body.Close(); closeBodyErr != nil && logger != nil { + logger.Error(fmt.Sprintf("closing response body error for GET %s", targetURL), log.Error(closeBodyErr)) + } + }() + + if resp.StatusCode != http.StatusOK { + promMetrics.ObserveHTTPClientRequest( + http.MethodGet, targetURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorUnexpectedStatusCode) + return OpenIDConfiguration{}, fmt.Errorf("unexpected HTTP code %d", resp.StatusCode) + } + + var openIDCfg OpenIDConfiguration + if err = json.NewDecoder(resp.Body).Decode(&openIDCfg); err != nil { + promMetrics.ObserveHTTPClientRequest( + http.MethodGet, targetURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorDecodeBody) + return OpenIDConfiguration{}, fmt.Errorf("decode response body json (Content-Type: %s): %w", + resp.Header.Get("Content-Type"), err) + } + + promMetrics.ObserveHTTPClientRequest(http.MethodGet, targetURL, resp.StatusCode, elapsed, "") + return openIDCfg, nil +} diff --git a/internal/idputil/trusted_issuers_store.go b/internal/idputil/trusted_issuers_store.go new file mode 100644 index 0000000..d7e1fd1 --- /dev/null +++ b/internal/idputil/trusted_issuers_store.go @@ -0,0 +1,80 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idputil + +import ( + "fmt" + "net/url" + "sync" + + "github.com/vasayxtx/go-glob" +) + +type TrustedIssuerURLMatcher func(issURL *url.URL) bool + +type TrustedIssuerStore struct { + mu sync.RWMutex + issuers map[string]string + issuerURLMatchers []TrustedIssuerURLMatcher +} + +func NewTrustedIssuerStore() *TrustedIssuerStore { + return &TrustedIssuerStore{ + issuers: make(map[string]string), + } +} + +func (s *TrustedIssuerStore) AddTrustedIssuer(issName, issURL string) { + s.mu.Lock() + s.issuers[issName] = issURL + s.mu.Unlock() +} + +func (s *TrustedIssuerStore) AddTrustedIssuerURL(issURL string) error { + s.mu.Lock() + defer s.mu.Unlock() + urlMatcher, err := makeTrustedIssuerURLMatcher(issURL) + if err != nil { + return err + } + s.issuerURLMatchers = append(s.issuerURLMatchers, urlMatcher) + return nil +} + +func (s *TrustedIssuerStore) GetURLForIssuer(issuer string) (string, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + if issuerURL, ok := s.issuers[issuer]; ok { + return issuerURL, true + } + + parsedIssURL, err := url.Parse(issuer) + if err != nil { + return "", false + } + for i := range s.issuerURLMatchers { + if s.issuerURLMatchers[i](parsedIssURL) { + return issuer, true + } + } + + return "", false +} + +func makeTrustedIssuerURLMatcher(urlPattern string) (TrustedIssuerURLMatcher, error) { + parsedURL, err := url.Parse(urlPattern) + if err != nil { + return nil, fmt.Errorf("parse issuer URL glob pattern: %w", err) + } + hostMatcher := glob.Compile(parsedURL.Host) + return func(issURL *url.URL) bool { + return hostMatcher(issURL.Host) && + parsedURL.Path == issURL.Path && + parsedURL.Scheme == issURL.Scheme && + parsedURL.RawQuery == issURL.RawQuery + }, nil +} diff --git a/internal/libinfo/doc.go b/internal/libinfo/doc.go new file mode 100644 index 0000000..7087edf --- /dev/null +++ b/internal/libinfo/doc.go @@ -0,0 +1,8 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package libinfo provides helpers for working with the library information. +package libinfo diff --git a/internal/libinfo/lib_info.go b/internal/libinfo/lib_info.go new file mode 100644 index 0000000..b08d088 --- /dev/null +++ b/internal/libinfo/lib_info.go @@ -0,0 +1,26 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package libinfo + +import ( + "fmt" +) + +const libName = "go-authkit" + +const libPath = "github.com/acronis/" + libName + +func MakeUserAgent(prependedUserAgent string) string { + if prependedUserAgent != "" { + prependedUserAgent += " " + } + return prependedUserAgent + " " + libName + "/" + GetLibVersion() +} + +func GetLogPrefix() string { + return fmt.Sprintf("[%s/%s]", libName, GetLibVersion()) +} diff --git a/internal/libinfo/version.go b/internal/libinfo/version.go new file mode 100644 index 0000000..a533b5b --- /dev/null +++ b/internal/libinfo/version.go @@ -0,0 +1,33 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package libinfo + +import ( + "sync" + + "runtime/debug" +) + +var libVersion string +var libVersionOnce sync.Once + +func initLibVersion() { + if buildInfo, ok := debug.ReadBuildInfo(); ok && buildInfo != nil { + for _, dep := range buildInfo.Deps { + if dep.Path == libPath { + libVersion = dep.Version + return + } + } + } + libVersion = "v0.0.0" +} + +func GetLibVersion() string { + libVersionOnce.Do(initLibVersion) + return libVersion +} diff --git a/internal/metrics/doc.go b/internal/metrics/doc.go new file mode 100644 index 0000000..6f9a198 --- /dev/null +++ b/internal/metrics/doc.go @@ -0,0 +1,8 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package metrics provides helpers for working with the library metrics. +package metrics diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go new file mode 100644 index 0000000..e7cf519 --- /dev/null +++ b/internal/metrics/metrics.go @@ -0,0 +1,174 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package metrics + +import ( + "strconv" + "sync" + "time" + + "github.com/acronis/go-appkit/lrucache" + "github.com/prometheus/client_golang/prometheus" + grpccodes "google.golang.org/grpc/codes" + + "github.com/acronis/go-authkit/internal/libinfo" +) + +const PrometheusNamespace = "go_authkit" + +const DefaultPrometheusLibInstanceLabel = "default" + +const ( + PrometheusLibInstanceLabel = "lib_instance" + PrometheusLibSourceLabel = "lib_source" +) + +func PrometheusLabels() prometheus.Labels { + return prometheus.Labels{"lib_version": libinfo.GetLibVersion()} +} + +const ( + HTTPClientRequestLabelMethod = "method" + HTTPClientRequestLabelURL = "url" + HTTPClientRequestLabelStatusCode = "status_code" + HTTPClientRequestLabelError = "error" + + GRPCClientRequestLabelMethod = "grpc_method" + GRPCClientRequestLabelCode = "grpc_code" +) + +const ( + HTTPRequestErrorDo = "do_request_error" + HTTPRequestErrorDecodeBody = "decode_body_error" + HTTPRequestErrorUnexpectedStatusCode = "unexpected_status_code" +) + +var requestDurationBuckets = []float64{0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10} + +var ( + prometheusMetrics *PrometheusMetrics + prometheusMetricsOnce sync.Once +) + +// PrometheusMetrics represents the collector of metrics. +type PrometheusMetrics struct { + HTTPClientRequestDuration *prometheus.HistogramVec + GRPCClientRequestDuration *prometheus.HistogramVec + TokenClaimsCache *lrucache.PrometheusMetrics + TokenNegativeCache *lrucache.PrometheusMetrics +} + +func GetPrometheusMetrics(instance string, source string) *PrometheusMetrics { + prometheusMetricsOnce.Do(func() { + prometheusMetrics = newPrometheusMetrics() + prometheusMetrics.MustRegister() + }) + if instance == "" { + instance = DefaultPrometheusLibInstanceLabel + } + return prometheusMetrics.MustCurryWith(map[string]string{ + PrometheusLibInstanceLabel: instance, + PrometheusLibSourceLabel: source, + }) +} + +func newPrometheusMetrics() *PrometheusMetrics { + curriedLabelNames := []string{PrometheusLibInstanceLabel, PrometheusLibSourceLabel} + makeLabelNames := func(names ...string) []string { + l := append(make([]string, 0, len(curriedLabelNames)+len(names)), curriedLabelNames...) + return append(l, names...) + } + + httpClientReqDuration := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: PrometheusNamespace, + Name: "http_client_request_duration_seconds", + Help: "A histogram of the http client request durations to IDP endpoints.", + Buckets: requestDurationBuckets, + ConstLabels: PrometheusLabels(), + }, + makeLabelNames(HTTPClientRequestLabelMethod, HTTPClientRequestLabelURL, + HTTPClientRequestLabelStatusCode, HTTPClientRequestLabelError), + ) + grpcClientReqDuration := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: PrometheusNamespace, + Name: "grpc_client_request_duration_seconds", + Help: "A histogram of the grpc client request durations to IDP endpoints.", + Buckets: requestDurationBuckets, + ConstLabels: PrometheusLabels(), + }, + makeLabelNames(GRPCClientRequestLabelMethod, GRPCClientRequestLabelCode), + ) + + tokenClaimsCache := lrucache.NewPrometheusMetricsWithOpts(lrucache.PrometheusMetricsOpts{ + Namespace: PrometheusNamespace + "_token_claims", + ConstLabels: PrometheusLabels(), + CurriedLabelNames: curriedLabelNames, + }) + + tokenNegativeCache := lrucache.NewPrometheusMetricsWithOpts(lrucache.PrometheusMetricsOpts{ + Namespace: PrometheusNamespace + "_token_negative", + ConstLabels: PrometheusLabels(), + CurriedLabelNames: curriedLabelNames, + }) + + return &PrometheusMetrics{ + HTTPClientRequestDuration: httpClientReqDuration, + GRPCClientRequestDuration: grpcClientReqDuration, + TokenClaimsCache: tokenClaimsCache, + TokenNegativeCache: tokenNegativeCache, + } +} + +// MustCurryWith curries the metrics collector with the provided labels. +func (pm *PrometheusMetrics) MustCurryWith(labels prometheus.Labels) *PrometheusMetrics { + return &PrometheusMetrics{ + HTTPClientRequestDuration: pm.HTTPClientRequestDuration.MustCurryWith(labels).(*prometheus.HistogramVec), + GRPCClientRequestDuration: pm.GRPCClientRequestDuration.MustCurryWith(labels).(*prometheus.HistogramVec), + TokenClaimsCache: pm.TokenClaimsCache.MustCurryWith(labels), + TokenNegativeCache: pm.TokenNegativeCache.MustCurryWith(labels), + } +} + +// MustRegister does registration of metrics collector in Prometheus and panics if any error occurs. +func (pm *PrometheusMetrics) MustRegister() { + prometheus.MustRegister( + pm.HTTPClientRequestDuration, + pm.GRPCClientRequestDuration, + ) + pm.TokenClaimsCache.MustRegister() + pm.TokenNegativeCache.MustRegister() +} + +// Unregister cancels registration of metrics collector in Prometheus. +func (pm *PrometheusMetrics) Unregister() { + prometheus.Unregister(pm.HTTPClientRequestDuration) + prometheus.Unregister(pm.GRPCClientRequestDuration) + pm.TokenClaimsCache.Unregister() + pm.TokenNegativeCache.Unregister() +} + +func (pm *PrometheusMetrics) ObserveHTTPClientRequest( + method string, targetURL string, statusCode int, elapsed time.Duration, errorType string, +) { + pm.HTTPClientRequestDuration.With(prometheus.Labels{ + HTTPClientRequestLabelMethod: method, + HTTPClientRequestLabelURL: targetURL, + HTTPClientRequestLabelStatusCode: strconv.Itoa(statusCode), + HTTPClientRequestLabelError: errorType, + }).Observe(elapsed.Seconds()) +} + +func (pm *PrometheusMetrics) ObserveGRPCClientRequest( + method string, code grpccodes.Code, elapsed time.Duration, +) { + pm.GRPCClientRequestDuration.With(prometheus.Labels{ + GRPCClientRequestLabelMethod: method, + GRPCClientRequestLabelCode: code.String(), + }).Observe(elapsed.Seconds()) +} diff --git a/internal/testing/doc.go b/internal/testing/doc.go new file mode 100644 index 0000000..a516b74 --- /dev/null +++ b/internal/testing/doc.go @@ -0,0 +1,8 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package testing provides internal testing utilities. +package testing diff --git a/internal/testing/server_token_introspector_mock.go b/internal/testing/server_token_introspector_mock.go new file mode 100644 index 0000000..9f4b247 --- /dev/null +++ b/internal/testing/server_token_introspector_mock.go @@ -0,0 +1,152 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package testing + +import ( + "context" + "crypto/sha256" + "net/http" + "net/url" + + "google.golang.org/grpc/metadata" + + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/idptoken/pb" + "github.com/acronis/go-authkit/jwt" +) + +type JWTParser interface { + Parse(ctx context.Context, token string) (*jwt.Claims, error) +} + +type HTTPServerTokenIntrospectorMock struct { + JWTParser JWTParser + + introspectionResults map[[sha256.Size]byte]idptoken.IntrospectionResult + jwtScopes map[string][]jwt.AccessPolicy + + Called bool + LastAuthorizationHeader string + LastIntrospectedToken string + LastFormValues url.Values +} + +func NewHTTPServerTokenIntrospectorMock() *HTTPServerTokenIntrospectorMock { + return &HTTPServerTokenIntrospectorMock{ + introspectionResults: make(map[[sha256.Size]byte]idptoken.IntrospectionResult), + jwtScopes: make(map[string][]jwt.AccessPolicy), + } +} + +func (m *HTTPServerTokenIntrospectorMock) SetResultForToken(token string, result idptoken.IntrospectionResult) { + m.introspectionResults[tokenToKey(token)] = result +} + +func (m *HTTPServerTokenIntrospectorMock) SetScopeForJWTID(jwtID string, scope []jwt.AccessPolicy) { + m.jwtScopes[jwtID] = scope +} + +func (m *HTTPServerTokenIntrospectorMock) IntrospectToken(r *http.Request, token string) idptoken.IntrospectionResult { + m.Called = true + m.LastAuthorizationHeader = r.Header.Get("Authorization") + m.LastIntrospectedToken = token + m.LastFormValues = r.Form + + if result, ok := m.introspectionResults[tokenToKey(token)]; ok { + return result + } + + claims, err := m.JWTParser.Parse(r.Context(), token) + if err != nil { + return idptoken.IntrospectionResult{Active: false} + } + result := idptoken.IntrospectionResult{Active: true, TokenType: idptoken.TokenTypeBearer, Claims: *claims} + if scopes, ok := m.jwtScopes[claims.ID]; ok { + result.Scope = scopes + } + return result +} + +func (m *HTTPServerTokenIntrospectorMock) ResetCallsInfo() { + m.Called = false + m.LastAuthorizationHeader = "" + m.LastIntrospectedToken = "" + m.LastFormValues = nil +} + +type GRPCServerTokenIntrospectorMock struct { + JWTParser JWTParser + + introspectionResults map[[sha256.Size]byte]*pb.IntrospectTokenResponse + scopes map[string][]*pb.AccessTokenScope + + Called bool + LastAuthorizationMeta string + LastRequest *pb.IntrospectTokenRequest +} + +func NewGRPCServerTokenIntrospectorMock() *GRPCServerTokenIntrospectorMock { + return &GRPCServerTokenIntrospectorMock{ + introspectionResults: make(map[[sha256.Size]byte]*pb.IntrospectTokenResponse), + scopes: make(map[string][]*pb.AccessTokenScope), + } +} + +func (m *GRPCServerTokenIntrospectorMock) SetResultForToken(token string, result *pb.IntrospectTokenResponse) { + m.introspectionResults[tokenToKey(token)] = result +} + +func (m *GRPCServerTokenIntrospectorMock) SetScopeForJWTID(jwtID string, scope []*pb.AccessTokenScope) { + m.scopes[jwtID] = scope +} + +func (m *GRPCServerTokenIntrospectorMock) IntrospectToken( + ctx context.Context, req *pb.IntrospectTokenRequest, +) (*pb.IntrospectTokenResponse, error) { + m.Called = true + if mdVal := metadata.ValueFromIncomingContext(ctx, "authorization"); len(mdVal) != 0 { + m.LastAuthorizationMeta = mdVal[0] + } else { + m.LastAuthorizationMeta = "" + } + m.LastRequest = req + + if result, ok := m.introspectionResults[tokenToKey(req.Token)]; ok { + return result, nil + } + + claims, err := m.JWTParser.Parse(ctx, req.Token) + if err != nil { + return &pb.IntrospectTokenResponse{Active: false}, nil + } + result := &pb.IntrospectTokenResponse{ + Active: true, + TokenType: idptoken.TokenTypeBearer, + Exp: claims.ExpiresAt.Unix(), + Aud: claims.Audience, + Jti: claims.ID, + Iss: claims.Issuer, + Sub: claims.Subject, + SubType: claims.SubType, + ClientId: claims.ClientID, + OwnerTenantUuid: claims.OwnerTenantUUID, + } + if scopes, ok := m.scopes[claims.ID]; ok { + result.Scope = scopes + } + return result, nil +} + +func (m *GRPCServerTokenIntrospectorMock) ResetCallsInfo() { + m.Called = false + m.LastAuthorizationMeta = "" + m.LastRequest = nil +} + +func tokenToKey(token string) [sha256.Size]byte { + return sha256.Sum256([]byte(token)) +} diff --git a/jwks/caching_client.go b/jwks/caching_client.go new file mode 100644 index 0000000..6954618 --- /dev/null +++ b/jwks/caching_client.go @@ -0,0 +1,169 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwks + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-appkit/lrucache" +) + +const DefaultCacheUpdateMinInterval = time.Minute * 1 + +// CachingClientOpts contains options for CachingClient. +type CachingClientOpts struct { + ClientOpts + + // CacheUpdateMinInterval is a minimal interval between cache updates for the same issuer. + CacheUpdateMinInterval time.Duration +} + +// CachingClient is a Client for getting keys from remote JWKS with a caching mechanism. +type CachingClient struct { + mu sync.RWMutex + rawClient *Client + issuerCache map[string]issuerCacheEntry + cacheUpdateMinInterval time.Duration +} + +const missingKeysCacheSize = 100 + +type issuerCacheEntry struct { + updatedAt time.Time + keys map[string]interface{} + missingKeys *lrucache.LRUCache[string, time.Time] +} + +// NewCachingClient returns a new Client that can cache fetched data. +func NewCachingClient(httpClient *http.Client, logger log.FieldLogger) *CachingClient { + return NewCachingClientWithOpts(httpClient, logger, CachingClientOpts{}) +} + +// NewCachingClientWithOpts returns a new Client that can cache fetched data with options. +func NewCachingClientWithOpts(httpClient *http.Client, logger log.FieldLogger, opts CachingClientOpts) *CachingClient { + if opts.CacheUpdateMinInterval == 0 { + opts.CacheUpdateMinInterval = DefaultCacheUpdateMinInterval + } + return &CachingClient{ + rawClient: NewClientWithOpts(httpClient, logger, opts.ClientOpts), + issuerCache: make(map[string]issuerCacheEntry), + cacheUpdateMinInterval: opts.CacheUpdateMinInterval, + } +} + +// GetRSAPublicKey searches JWK with passed key ID in JWKS and returns decoded RSA public key for it. +// The last one can be used for verifying JWT signature. Obtained JWKS is cached. +// If passed issuer URL or key ID is not found in the cache, JWKS will be fetched again, +// but not more than once in a some (configurable) period of time. +func (cc *CachingClient) GetRSAPublicKey(ctx context.Context, issuerURL, keyID string) (interface{}, error) { + pubKey, found, needInvalidate := cc.getPubKeyFromCache(issuerURL, keyID) + if found { + return pubKey, nil + } + if needInvalidate { + var err error + if pubKey, found, err = cc.getPubKeyFromCacheAndInvalidate(ctx, issuerURL, keyID); err != nil || found { + return pubKey, err + } + } + return nil, &JWKNotFoundError{IssuerURL: issuerURL, KeyID: keyID} +} + +// InvalidateCacheIfNeeded does cache invalidation for specific issuer URL if it's necessary. +func (cc *CachingClient) InvalidateCacheIfNeeded(ctx context.Context, issuerURL string) error { + cc.mu.Lock() + defer cc.mu.Unlock() + + var missingKeys *lrucache.LRUCache[string, time.Time] + issCache, found := cc.issuerCache[issuerURL] + if found { + if time.Since(issCache.updatedAt) < cc.cacheUpdateMinInterval { + return nil + } + missingKeys = issCache.missingKeys + } else { + var err error + if missingKeys, err = lrucache.New[string, time.Time](missingKeysCacheSize, nil); err != nil { + return fmt.Errorf("new lru cache for missing keys: %w", err) + } + } + + pubKeys, err := cc.rawClient.getRSAPubKeysForIssuer(ctx, issuerURL) + if err != nil { + return fmt.Errorf("get rsa public keys for issuer %q: %w", issuerURL, err) + } + cc.issuerCache[issuerURL] = issuerCacheEntry{ + updatedAt: time.Now(), + keys: pubKeys, + missingKeys: missingKeys, + } + return nil +} + +func (cc *CachingClient) getPubKeyFromCache( + issuerURL, keyID string, +) (pubKey interface{}, found bool, needInvalidate bool) { + cc.mu.RLock() + defer cc.mu.RUnlock() + + issCache, issFound := cc.issuerCache[issuerURL] + if !issFound { + return nil, false, true + } + if pubKey, found = issCache.keys[keyID]; found { + return + } + missedAt, miss := issCache.missingKeys.Get(keyID) + if !miss || time.Since(missedAt) > cc.cacheUpdateMinInterval { + return nil, false, true + } + return nil, false, false +} + +func (cc *CachingClient) getPubKeyFromCacheAndInvalidate( + ctx context.Context, issuerURL, keyID string, +) (pubKey interface{}, found bool, err error) { + cc.mu.Lock() + defer cc.mu.Unlock() + + var missingKeys *lrucache.LRUCache[string, time.Time] + if issCache, issFound := cc.issuerCache[issuerURL]; issFound { + if pubKey, found = issCache.keys[keyID]; found { + return pubKey, true, nil + } + missedAt, miss := issCache.missingKeys.Get(keyID) + if miss && time.Since(missedAt) < cc.cacheUpdateMinInterval { + return nil, false, nil + } + missingKeys = issCache.missingKeys + } else { + missingKeys, err = lrucache.New[string, time.Time](missingKeysCacheSize, nil) + if err != nil { + return nil, false, fmt.Errorf("new lru cache for missing keys: %w", err) + } + } + + pubKeys, err := cc.rawClient.getRSAPubKeysForIssuer(ctx, issuerURL) + if err != nil { + return nil, false, fmt.Errorf("get rsa public keys for issuer %q: %w", issuerURL, err) + } + pubKey, found = pubKeys[keyID] + if !found { + missingKeys.Add(keyID, time.Now()) + } + cc.issuerCache[issuerURL] = issuerCacheEntry{ + updatedAt: time.Now(), + keys: pubKeys, + missingKeys: missingKeys, + } + return pubKey, found, nil +} diff --git a/jwks/caching_client_test.go b/jwks/caching_client_test.go new file mode 100644 index 0000000..395bbba --- /dev/null +++ b/jwks/caching_client_test.go @@ -0,0 +1,110 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwks_test + +import ( + "context" + "crypto/rsa" + "errors" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/acronis/go-appkit/log" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/jwks" +) + +func TestCachingClient_GetRSAPublicKey(t *testing.T) { + t.Run("ok", func(t *testing.T) { + jwksHandler := &idptest.JWKSHandler{} + jwksServer := httptest.NewServer(jwksHandler) + defer jwksServer.Close() + issuerConfigHandler := &idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL} + issuerConfigServer := httptest.NewServer(issuerConfigHandler) + defer issuerConfigServer.Close() + + cachingClient := jwks.NewCachingClientWithOpts(http.DefaultClient, log.NewDisabledLogger(), + jwks.CachingClientOpts{CacheUpdateMinInterval: time.Second * 10}) + var wg sync.WaitGroup + const callsNum = 10 + wg.Add(callsNum) + errs := make(chan error, callsNum) + pubKeys := make(chan interface{}, callsNum) + for i := 0; i < callsNum; i++ { + go func() { + defer wg.Done() + pubKey, err := cachingClient.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + if err != nil { + errs <- err + return + } + pubKeys <- pubKey + }() + } + wg.Wait() + close(errs) + close(pubKeys) + for err := range errs { + require.NoError(t, err) + } + for pubKey := range pubKeys { + require.NotNil(t, pubKey) + require.IsType(t, &rsa.PublicKey{}, pubKey) + } + require.EqualValues(t, 1, issuerConfigHandler.ServedCount()) + require.EqualValues(t, 1, jwksHandler.ServedCount()) + }) + + t.Run("jwk not found", func(t *testing.T) { + jwksHandler := &idptest.JWKSHandler{} + jwksServer := httptest.NewServer(jwksHandler) + defer jwksServer.Close() + issuerConfigHandler := &idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL} + issuerConfigServer := httptest.NewServer(issuerConfigHandler) + defer issuerConfigServer.Close() + + const unknownKeyID = "77777777-7777-7777-7777-777777777777" + const cacheUpdateMinInterval = time.Second * 1 + + cachingClient := jwks.NewCachingClientWithOpts(http.DefaultClient, log.NewDisabledLogger(), + jwks.CachingClientOpts{CacheUpdateMinInterval: cacheUpdateMinInterval}) + + doGetPublicKeyByUnknownID := func(callsNum int) { + t.Helper() + var wg sync.WaitGroup + wg.Add(callsNum) + for i := 0; i < callsNum; i++ { + go func() { + defer wg.Done() + pubKey, err := cachingClient.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, unknownKeyID) + require.Error(t, err) + var jwkErr *jwks.JWKNotFoundError + require.True(t, errors.As(err, &jwkErr)) + require.Equal(t, issuerConfigServer.URL, jwkErr.IssuerURL) + require.Equal(t, unknownKeyID, jwkErr.KeyID) + require.Nil(t, pubKey) + }() + } + wg.Wait() + } + + doGetPublicKeyByUnknownID(10) + require.EqualValues(t, 1, issuerConfigHandler.ServedCount()) + require.EqualValues(t, 1, jwksHandler.ServedCount()) + + time.Sleep(cacheUpdateMinInterval * 2) + + doGetPublicKeyByUnknownID(10) + require.EqualValues(t, 2, issuerConfigHandler.ServedCount()) + require.EqualValues(t, 2, jwksHandler.ServedCount()) + }) +} diff --git a/jwks/client.go b/jwks/client.go new file mode 100644 index 0000000..9af6075 --- /dev/null +++ b/jwks/client.go @@ -0,0 +1,138 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwks + +import ( + "context" + "crypto" + "crypto/rsa" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/acronis/go-appkit/log" + "github.com/mendsley/gojwk" + + "github.com/acronis/go-authkit/internal/idputil" + "github.com/acronis/go-authkit/internal/metrics" +) + +const OpenIDConfigurationPath = "/.well-known/openid-configuration" + +type jwksData struct { + Keys []*gojwk.Key `json:"keys"` +} + +// ClientOpts contains options for the JWKS client. +type ClientOpts struct { + // PrometheusLibInstanceLabel is a label for Prometheus metrics. + // It allows distinguishing metrics from different instances of the same library. + PrometheusLibInstanceLabel string +} + +// Client gets public keys from remote JWKS. +// It uses jwks_uri field from /.well-known/openid-configuration endpoint. +// NOTE: CachingClient should be used in a typical service +// to avoid making HTTP requests on each JWT verification. +type Client struct { + httpClient *http.Client + logger log.FieldLogger + promMetrics *metrics.PrometheusMetrics +} + +// NewClient returns a new Client. +func NewClient(httpClient *http.Client, logger log.FieldLogger) *Client { + return NewClientWithOpts(httpClient, logger, ClientOpts{}) +} + +// NewClientWithOpts returns a new Client with options. +func NewClientWithOpts(httpClient *http.Client, logger log.FieldLogger, opts ClientOpts) *Client { + promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "jwks_client") + return &Client{httpClient, logger, promMetrics} +} + +func (c *Client) getRSAPubKeysForIssuer(ctx context.Context, issuerURL string) (map[string]interface{}, error) { + openIDConfigURL := strings.TrimPrefix(issuerURL, "/") + OpenIDConfigurationPath + openIDConfig, err := idputil.GetOpenIDConfiguration( + ctx, c.httpClient, openIDConfigURL, nil, c.logger, c.promMetrics) + if err != nil { + return nil, &GetOpenIDConfigurationError{Inner: err, URL: openIDConfigURL} + } + jwksRespData, err := c.getJWKS(ctx, openIDConfig.JWKSURI) + if err != nil { + return nil, &GetJWKSError{Inner: err, URL: openIDConfig.JWKSURI, OpenIDConfigurationURL: openIDConfigURL} + } + c.logger.Info(fmt.Sprintf("%d keys fetched (jwks_url: %s)", len(jwksRespData.Keys), openIDConfig.JWKSURI)) + + pubKeys := make(map[string]interface{}, len(jwksRespData.Keys)) + for _, jwk := range jwksRespData.Keys { + var pubKey crypto.PublicKey + if pubKey, err = jwk.DecodePublicKey(); err != nil { + c.logger.Error(fmt.Sprintf("decoding JWK (kid: %s, jwks_url: %s) to public key error", + jwk.Kid, openIDConfig.JWKSURI), log.Error(err)) + continue + } + rsaPubKey, ok := pubKey.(*rsa.PublicKey) + if !ok { + c.logger.Error(fmt.Sprintf("converting JWK (kid: %s, jwks_url: %s) to RSA public key error", + jwk.Kid, openIDConfig.JWKSURI), log.Error(err)) + continue + } + pubKeys[jwk.Kid] = rsaPubKey + } + return pubKeys, nil +} + +// GetRSAPublicKey gets JWK from JWKS and returns decoded RSA public key. The last one can be used for verifying JWT signature. +func (c *Client) GetRSAPublicKey(ctx context.Context, issuerURL, keyID string) (interface{}, error) { + pubKeys, err := c.getRSAPubKeysForIssuer(ctx, issuerURL) + if err != nil { + return nil, fmt.Errorf("get rsa public keys for issuer %q: %w", issuerURL, err) + } + pubKey, ok := pubKeys[keyID] + if !ok { + return nil, &JWKNotFoundError{IssuerURL: issuerURL, KeyID: keyID} + } + return pubKey, nil +} + +func (c *Client) getJWKS(ctx context.Context, jwksURL string) (jwksData, error) { + req, err := http.NewRequest(http.MethodGet, jwksURL, http.NoBody) + if err != nil { + return jwksData{}, fmt.Errorf("new request: %w", err) + } + startTime := time.Now() + resp, err := c.httpClient.Do(req.WithContext(ctx)) + elapsed := time.Since(startTime) + if err != nil { + c.promMetrics.ObserveHTTPClientRequest(http.MethodGet, jwksURL, 0, elapsed, metrics.HTTPRequestErrorDo) + return jwksData{}, fmt.Errorf("do request: %w", err) + } + defer func() { + if closeBodyErr := resp.Body.Close(); closeBodyErr != nil { + c.logger.Error(fmt.Sprintf("closing response body error for GET %s", jwksURL), log.Error(closeBodyErr)) + } + }() + + if resp.StatusCode != http.StatusOK { + c.promMetrics.ObserveHTTPClientRequest( + http.MethodGet, jwksURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorUnexpectedStatusCode) + return jwksData{}, fmt.Errorf("unexpected HTTP code %d", resp.StatusCode) + } + + var res jwksData + if err = json.NewDecoder(resp.Body).Decode(&res); err != nil { + c.promMetrics.ObserveHTTPClientRequest( + http.MethodGet, jwksURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorDecodeBody) + return jwksData{}, fmt.Errorf("decode response body json: %w", err) + } + + c.promMetrics.ObserveHTTPClientRequest(http.MethodGet, jwksURL, resp.StatusCode, elapsed, "") + return res, nil +} diff --git a/jwks/client_test.go b/jwks/client_test.go new file mode 100644 index 0000000..96e3525 --- /dev/null +++ b/jwks/client_test.go @@ -0,0 +1,160 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwks_test + +import ( + "context" + "crypto/rsa" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/acronis/go-appkit/log" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/jwks" +) + +func TestClient_GetRSAPublicKey(t *testing.T) { + t.Run("ok", func(t *testing.T) { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + require.NoError(t, err) + require.NotNil(t, pubKey) + require.IsType(t, &rsa.PublicKey{}, pubKey) + }) + + t.Run("issuer openid configuration unavailable", func(t *testing.T) { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + issuerConfigServer.Close() // Close the server immediately. + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + require.Error(t, err) + var openIDCfgErr *jwks.GetOpenIDConfigurationError + require.True(t, errors.As(err, &openIDCfgErr)) + require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, openIDCfgErr.URL) + require.ErrorContains(t, openIDCfgErr.Inner, "connection refused") + require.Nil(t, pubKey) + }) + + t.Run("openid configuration server respond internal error", func(t *testing.T) { + issuerConfigServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusInternalServerError) + })) + defer issuerConfigServer.Close() + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + require.Error(t, err) + var openIDCfgErr *jwks.GetOpenIDConfigurationError + require.True(t, errors.As(err, &openIDCfgErr)) + require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, openIDCfgErr.URL) + require.EqualError(t, openIDCfgErr.Inner, "unexpected HTTP code 500") + require.Nil(t, pubKey) + }) + + t.Run("openid configuration server respond invalid json", func(t *testing.T) { + issuerConfigServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + _, err := rw.Write([]byte(`{"invalid-json"]`)) + require.NoError(t, err) + })) + defer issuerConfigServer.Close() + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + require.Error(t, err) + var openIDCfgErr *jwks.GetOpenIDConfigurationError + require.True(t, errors.As(err, &openIDCfgErr)) + require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, openIDCfgErr.URL) + var jsonSyntaxErr *json.SyntaxError + require.True(t, errors.As(openIDCfgErr, &jsonSyntaxErr)) + require.Nil(t, pubKey) + }) + + t.Run("jwks server unavailable", func(t *testing.T) { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + jwksServer.Close() // Close the server immediately. + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + require.Error(t, err) + var jwksErr *jwks.GetJWKSError + require.True(t, errors.As(err, &jwksErr)) + require.Equal(t, jwksServer.URL, jwksErr.URL) + require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, jwksErr.OpenIDConfigurationURL) + require.ErrorContains(t, jwksErr.Inner, "connection refused") + require.Nil(t, pubKey) + }) + + t.Run("jwks server respond internal error", func(t *testing.T) { + jwksServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusInternalServerError) + })) + defer jwksServer.Close() + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) + require.Error(t, err) + var jwksErr *jwks.GetJWKSError + require.True(t, errors.As(err, &jwksErr)) + require.Equal(t, jwksServer.URL, jwksErr.URL) + require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, jwksErr.OpenIDConfigurationURL) + require.EqualError(t, jwksErr.Inner, "unexpected HTTP code 500") + require.Nil(t, pubKey) + }) + + t.Run("jwk not found", func(t *testing.T) { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + const unknownKeyID = "77777777-7777-7777-7777-777777777777" + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, unknownKeyID) + require.Error(t, err) + var jwkErr *jwks.JWKNotFoundError + require.True(t, errors.As(err, &jwkErr)) + require.Equal(t, issuerConfigServer.URL, jwkErr.IssuerURL) + require.Equal(t, unknownKeyID, jwkErr.KeyID) + require.Nil(t, pubKey) + }) + + t.Run("context canceled", func(t *testing.T) { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + ctx, cancelCtxFn := context.WithCancel(context.Background()) + cancelCtxFn() // Emulate canceling context. + pubKey, err := client.GetRSAPublicKey(ctx, issuerConfigServer.URL, idptest.TestKeyID) + require.Error(t, err) + var openIDCfgErr *jwks.GetOpenIDConfigurationError + require.True(t, errors.As(err, &openIDCfgErr)) + require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, openIDCfgErr.URL) + require.ErrorIs(t, openIDCfgErr, context.Canceled) + require.Nil(t, pubKey) + }) +} diff --git a/jwks/doc.go b/jwks/doc.go new file mode 100644 index 0000000..56650b7 --- /dev/null +++ b/jwks/doc.go @@ -0,0 +1,8 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package jwks contains clients for getting public keys from JWKS. +package jwks diff --git a/jwks/errors.go b/jwks/errors.go new file mode 100644 index 0000000..bfe9e45 --- /dev/null +++ b/jwks/errors.go @@ -0,0 +1,49 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwks + +import "fmt" + +// GetOpenIDConfigurationError is an error that may occur during getting openID configuration for issuer. +type GetOpenIDConfigurationError struct { + Inner error + URL string +} + +func (e *GetOpenIDConfigurationError) Error() string { + return fmt.Sprintf("error while getting OpenID configuration (URL: %q): %s", e.URL, e.Inner.Error()) +} + +func (e *GetOpenIDConfigurationError) Unwrap() error { + return e.Inner +} + +// GetJWKSError is an error that may occur during getting JWKS. +type GetJWKSError struct { + Inner error + URL string + OpenIDConfigurationURL string +} + +func (e *GetJWKSError) Error() string { + return fmt.Sprintf("error while getting JWKS data (URL: %q, OpenID configuration URL: %q): %s", + e.URL, e.OpenIDConfigurationURL, e.Inner.Error()) +} + +func (e *GetJWKSError) Unwrap() error { + return e.Inner +} + +// JWKNotFoundError is an error that occurs when JWK is not found by kid. +type JWKNotFoundError struct { + IssuerURL string + KeyID string +} + +func (e *JWKNotFoundError) Error() string { + return fmt.Sprintf("JWK not found (Key ID: %q, Issuer URL: %q)", e.KeyID, e.IssuerURL) +} diff --git a/jwt/caching_parser.go b/jwt/caching_parser.go new file mode 100644 index 0000000..f0cfdc0 --- /dev/null +++ b/jwt/caching_parser.go @@ -0,0 +1,113 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwt + +import ( + "context" + "crypto/sha256" + "fmt" + "unsafe" + + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-appkit/lrucache" + jwtgo "github.com/golang-jwt/jwt/v5" + + "github.com/acronis/go-authkit/internal/metrics" +) + +const DefaultClaimsCacheMaxEntries = 1000 + +type CachingParserOpts struct { + ParserOpts + CacheMaxEntries int + CachePrometheusInstanceLabel string +} + +// ClaimsCache is an interface that must be implemented by used cache implementations. +type ClaimsCache interface { + Get(key [sha256.Size]byte) (*Claims, bool) + Add(key [sha256.Size]byte, value *Claims) + Purge() + Len() int +} + +// CachingParser uses the functionality of Parser to parse JWT, but stores resulted Claims objects in the cache. +type CachingParser struct { + *Parser + ClaimsCache ClaimsCache +} + +func NewCachingParser(keysProvider KeysProvider, logger log.FieldLogger) (*CachingParser, error) { + return NewCachingParserWithOpts(keysProvider, logger, CachingParserOpts{}) +} + +func NewCachingParserWithOpts( + keysProvider KeysProvider, logger log.FieldLogger, opts CachingParserOpts, +) (*CachingParser, error) { + promMetrics := metrics.GetPrometheusMetrics(opts.CachePrometheusInstanceLabel, "jwt_parser") + if opts.CacheMaxEntries == 0 { + opts.CacheMaxEntries = DefaultClaimsCacheMaxEntries + } + cache, err := lrucache.New[[sha256.Size]byte, *Claims](opts.CacheMaxEntries, promMetrics.TokenClaimsCache) + if err != nil { + return nil, err + } + return &CachingParser{ + Parser: NewParserWithOpts(keysProvider, logger, opts.ParserOpts), + ClaimsCache: cache, + }, nil +} + +// getTokenHash converts an access token to a string hash that is used as a cache key. +func getTokenHash(token []byte) [sha256.Size]byte { + return sha256.Sum256(token) +} + +// stringToBytesUnsafe converts string to byte slice without memory allocation. (both heap and stack) +func stringToBytesUnsafe(s string) []byte { + // nolint: gosec // memory optimization to prevent redundant slice copying + return unsafe.Slice(unsafe.StringData(s), len(s)) +} + +// Parse calls Parse method of embedded original Parser but stores result into cache. +func (cp *CachingParser) Parse(ctx context.Context, token string) (*Claims, error) { + key := getTokenHash(stringToBytesUnsafe(token)) + cachedClaims, foundInCache, validationErr := cp.getFromCacheAndValidateIfNeeded(key) + if foundInCache { + if validationErr != nil { + return nil, validationErr + } + return cachedClaims, nil + } + claims, err := cp.Parser.Parse(ctx, token) + if err != nil { + return nil, err + } + cp.ClaimsCache.Add(key, claims) + return claims, nil +} + +func (cp *CachingParser) getFromCacheAndValidateIfNeeded(key [sha256.Size]byte) (claims *Claims, found bool, err error) { + cachedClaims, found := cp.ClaimsCache.Get(key) + if !found { + return nil, false, nil + } + if !cp.Parser.skipClaimsValidation { + if err = cp.Parser.claimsValidator.Validate(cachedClaims); err != nil { + return nil, true, fmt.Errorf("%w: %w", jwtgo.ErrTokenInvalidClaims, err) + } + if err = cp.Parser.customValidator(cachedClaims); err != nil { + return nil, true, fmt.Errorf("%w: %w", jwtgo.ErrTokenInvalidClaims, err) + } + } + return cachedClaims, true, nil +} + +// InvalidateClaimsCache removes all preserved parsed Claims objects from cache. +func (cp *CachingParser) InvalidateClaimsCache() { + cp.ClaimsCache.Purge() +} diff --git a/jwt/caching_parser_test.go b/jwt/caching_parser_test.go new file mode 100644 index 0000000..f0ede58 --- /dev/null +++ b/jwt/caching_parser_test.go @@ -0,0 +1,109 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwt_test + +import ( + "context" + "crypto/sha256" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +func getTokenHash(token []byte) [sha256.Size]byte { + tokenCheckSum := sha256.Sum256(token) + return tokenCheckSum +} + +func TestGetTokenHash(t *testing.T) { + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} + tokenString := []byte(idptest.MustMakeTokenStringSignedWithTestKey(claims)) + + th := getTokenHash(tokenString) + require.NotEmpty(t, th, "generated token hash must not be an empty string") + th2 := getTokenHash(tokenString) + require.Equal(t, th, th2, "two hashes of the same token must be equal") + + claims2 := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: "other" + testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(12 * time.Minute))}} + tokenString2 := []byte(idptest.MustMakeTokenStringSignedWithTestKey(claims2)) + th3 := getTokenHash(tokenString2) + require.NotEqual(t, th, th3, "two hashes of different tokens must be different") +} + +func TestCachingParser_Parse(t *testing.T) { + logger := log.NewDisabledLogger() + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} + tokenString := idptest.MustMakeTokenStringSignedWithTestKey(claims) + + parser, err := jwt.NewCachingParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + require.NoError(t, err) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + + var parsedClaims *jwt.Claims + parsedClaims, err = parser.Parse(context.Background(), tokenString) + require.NoError(t, err, "caching parser must not return error from Parse method") + require.Equal(t, claims.Scope, parsedClaims.Scope, "unexpected claims value produced by caching parser") + + require.Equal(t, 1, parser.ClaimsCache.Len(), + "one claims object must be cached after successful parse operation") + + tokenKey := getTokenHash([]byte(tokenString)) + cachedClaims, found := parser.ClaimsCache.Get(tokenKey) + require.True(t, found, "cached claims object must be found by token hash") + require.Equal(t, claims.Scope, cachedClaims.Scope, "unexpected claims value fetched from parser cache") + + parser.InvalidateClaimsCache() + require.Equal(t, 0, parser.ClaimsCache.Len(), + "parser cache must be empty after invalidation") +} + +func TestCachingParser_CheckExpiration(t *testing.T) { + const jwtTTL = 2 * time.Second + + logger := log.NewDisabledLogger() + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(jwtTTL))}} + tokenString := idptest.MustMakeTokenStringSignedWithTestKey(claims) + + parser, err := jwt.NewCachingParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + require.NoError(t, err) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + + var parsedClaims *jwt.Claims + parsedClaims, err = parser.Parse(context.Background(), tokenString) + require.NoError(t, err, "caching parser must not return error from Parse method") + require.Equal(t, claims.Scope, parsedClaims.Scope, "unexpected claims value produced by caching parser") + + require.Equal(t, 1, parser.ClaimsCache.Len(), + "one claims object must be cached after successful parse operation") + + time.Sleep(jwtTTL * 2) + + parsedClaims, err = parser.Parse(context.Background(), tokenString) + require.Error(t, err, "caching parser must return error since cached jwt is expired") + require.Nil(t, parsedClaims) +} diff --git a/jwt/doc.go b/jwt/doc.go new file mode 100644 index 0000000..4280943 --- /dev/null +++ b/jwt/doc.go @@ -0,0 +1,8 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +// Package jwt provides primitives for working with JWT (Parser, Claims, and so on). +package jwt diff --git a/jwt/errors.go b/jwt/errors.go new file mode 100644 index 0000000..fb3d89d --- /dev/null +++ b/jwt/errors.go @@ -0,0 +1,56 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwt + +import ( + "fmt" +) + +// SignAlgUnknownError represents an error when JWT signing algorithm is unknown. +type SignAlgUnknownError struct { + Alg string +} + +func (e *SignAlgUnknownError) Error() string { + return fmt.Sprintf("JWT has unknown signing algorithm %q", e.Alg) +} + +// IssuerUntrustedError represents an error when JWT issuer is untrusted. +type IssuerUntrustedError struct { + Claims *Claims +} + +func (e *IssuerUntrustedError) Error() string { + return fmt.Sprintf("JWT issuer %q untrusted", e.Claims.Issuer) +} + +// IssuerMissingError represents an error when JWT issuer is missing. +type IssuerMissingError struct { + Claims *Claims +} + +func (e *IssuerMissingError) Error() string { + return "JWT issuer missing" +} + +// AudienceMissingError represents an error when JWT audience is missing, but it's required. +type AudienceMissingError struct { + Claims *Claims +} + +func (e *AudienceMissingError) Error() string { + return "JWT audience missing" +} + +// AudienceNotExpectedError represents an error when JWT contains not expected audience. +type AudienceNotExpectedError struct { + Claims *Claims +} + +func (e *AudienceNotExpectedError) Error() string { + return fmt.Sprintf("JWT audience %q not expected", e.Claims.Audience) +} diff --git a/jwt/jwt.go b/jwt/jwt.go new file mode 100644 index 0000000..f1be148 --- /dev/null +++ b/jwt/jwt.go @@ -0,0 +1,255 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwt + +import ( + "context" + "errors" + "fmt" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/vasayxtx/go-glob" + + "github.com/acronis/go-authkit/internal/idputil" +) + +// KeysProvider is an interface for providing keys for verifying JWT. +type KeysProvider interface { + GetRSAPublicKey(ctx context.Context, issuer, keyID string) (interface{}, error) +} + +// CachingKeysProvider is an interface for providing keys for verifying JWT. +// Unlike KeysProvider, it supports caching of obtained keys. +type CachingKeysProvider interface { + KeysProvider + InvalidateCacheIfNeeded(ctx context.Context, issuer string) error +} + +// ParserOpts additional options for parser. +type ParserOpts struct { + SkipClaimsValidation bool + RequireAudience bool + ExpectedAudience []string + TrustedIssuerNotFoundFallback TrustedIssNotFoundFallback +} + +type audienceMatcher func(aud string) bool + +// TrustedIssNotFoundFallback is a function called when given issuer is not found in the list of trusted ones. +// For example, it could be analyzed and then added to the list by calling AddTrustedIssuerURL method. +type TrustedIssNotFoundFallback func(ctx context.Context, p *Parser, iss string) (issURL string, issFound bool) + +// Parser is an object for parsing, validation and verification JWT. +type Parser struct { + parser *jwtgo.Parser + claimsValidator *jwtgo.Validator + customValidator func(claims *Claims) error + skipClaimsValidation bool + keysProvider KeysProvider + + trustedIssuerStore *idputil.TrustedIssuerStore + trustedIssuerNotFoundFallback TrustedIssNotFoundFallback + + logger log.FieldLogger +} + +// NewParser creates new JWT parser with specified keys provider. +func NewParser(keysProvider KeysProvider, logger log.FieldLogger) *Parser { + return NewParserWithOpts(keysProvider, logger, ParserOpts{}) +} + +// NewParserWithOpts creates new JWT parser with specified keys provider and additional options. +func NewParserWithOpts(keysProvider KeysProvider, logger log.FieldLogger, opts ParserOpts) *Parser { + var audienceMatchers []audienceMatcher + for _, audPattern := range opts.ExpectedAudience { + audienceMatchers = append(audienceMatchers, glob.Compile(audPattern)) + } + return &Parser{ + parser: jwtgo.NewParser(jwtgo.WithExpirationRequired()), + claimsValidator: jwtgo.NewValidator(jwtgo.WithExpirationRequired()), + customValidator: makeCustomAudienceValidator(opts.RequireAudience, audienceMatchers), + skipClaimsValidation: opts.SkipClaimsValidation, + keysProvider: keysProvider, + trustedIssuerStore: idputil.NewTrustedIssuerStore(), + trustedIssuerNotFoundFallback: opts.TrustedIssuerNotFoundFallback, + logger: logger, + } +} + +// AddTrustedIssuer adds trusted issuer with specified name and URL. +func (p *Parser) AddTrustedIssuer(issName, issURL string) { + p.trustedIssuerStore.AddTrustedIssuer(issName, issURL) +} + +// AddTrustedIssuerURL adds trusted issuer URL. +func (p *Parser) AddTrustedIssuerURL(issURL string) error { + return p.trustedIssuerStore.AddTrustedIssuerURL(issURL) +} + +// GetURLForIssuer returns URL for issuer if it is trusted. +func (p *Parser) GetURLForIssuer(issuer string) (string, bool) { + return p.trustedIssuerStore.GetURLForIssuer(issuer) +} + +// Parse parses, validates and verifies passed token (it's string representation). Parsed claims is returned. +func (p *Parser) Parse(ctx context.Context, token string) (*Claims, error) { + keyFunc := p.getKeyFunc(ctx) + claims := validatableClaims{customValidator: p.customValidator} + if _, err := p.parser.ParseWithClaims(token, &claims, keyFunc); err != nil { + if !errors.Is(err, jwtgo.ErrTokenSignatureInvalid) { + return nil, err + } + + // If keys provider supports caching, we may try to invalidate it and try parsing JWT again. + cachingKeysProvider, ok := p.keysProvider.(CachingKeysProvider) + if !ok { + return nil, err + } + + issuerURL, issuerURLFound := p.getURLForIssuerWithCallback(ctx, claims.Issuer) + if !issuerURLFound { + return nil, err + } + if err = cachingKeysProvider.InvalidateCacheIfNeeded(ctx, issuerURL); err != nil { + p.logger.Error(fmt.Sprintf("keys provider invalidating cache error for issuer %q", issuerURL), + log.Error(err)) + return nil, err + } + + if _, err = p.parser.ParseWithClaims(token, &claims, keyFunc); err != nil { + return nil, err + } + } + + return &claims.Claims, nil +} + +func (p *Parser) getKeyFunc(ctx context.Context) func(token *jwtgo.Token) (interface{}, error) { + return func(token *jwtgo.Token) (i interface{}, err error) { + switch signAlg := token.Method.Alg(); signAlg { + case "none": //nolint:goconst + return nil, jwtgo.NoneSignatureTypeDisallowedError + + case "RS256", "RS384", "RS512": + // Empty kid is LEGAL, not all IDP impl support kid. + kidStr := "" + if kid, found := token.Header["kid"]; found { + kidStr = kid.(string) + } + claims := token.Claims.(*validatableClaims) + if claims.Issuer == "" { + return nil, &IssuerMissingError{&claims.Claims} + } + issuerURL, issuerURLFound := p.getURLForIssuerWithCallback(ctx, claims.Issuer) + if !issuerURLFound { + return nil, &IssuerUntrustedError{&claims.Claims} + } + return p.keysProvider.GetRSAPublicKey(ctx, issuerURL, kidStr) + + default: + return nil, &SignAlgUnknownError{signAlg} + } + } +} + +func (p *Parser) getURLForIssuerWithCallback(ctx context.Context, issuer string) (string, bool) { + issURL, issFound := p.GetURLForIssuer(issuer) + if issFound { + return issURL, true + } + if p.trustedIssuerNotFoundFallback == nil { + return "", false + } + return p.trustedIssuerNotFoundFallback(ctx, p, issuer) +} + +// Claims represents an extended version of JWT claims. +type Claims struct { + jwtgo.RegisteredClaims + Scope []AccessPolicy `json:"scope,omitempty"` + Version int `json:"ver,omitempty"` + UserID string `json:"uid,omitempty"` + OriginID string `json:"origin,omitempty"` + ClientID string `json:"client_id,omitempty"` + TOTPTime int64 `json:"totp_time,omitempty"` + SubType string `json:"sub_type,omitempty"` + OwnerTenantUUID string `json:"owner_tuid,omitempty"` +} + +// AccessPolicy represents a single access policy. +type AccessPolicy struct { + // TenantID equals to tenant ID for which access is granted (if resource is not specified) + // or which resource is owned by (if resource is specified). + // max length is 36 characters (uuid) + TenantID string `json:"tid,omitempty"` + + // TenantUUID equals to tenant UUID for which access is granted (if resource is not specified) + // or which resource is owned by (if resource is specified). + // max length is 36 characters (uuid) + TenantUUID string `json:"tuid,omitempty"` + + // ResourceServerID must be unique resource server instance or cluster ID. + // max length is 36 characters [a-Z0-9-_] + ResourceServerID string `json:"rs,omitempty"` + + // ResourceNamespace AKA resource type, partitions resources within resource server. + // E.g.: storage, task-manager, account-server, resource-manager, policy-manager etc. + // max length is 36 characters [a-Z0-9-_] + ResourceNamespace string `json:"rn,omitempty"` + + // ResourcePath AKA resource ID AKA resource pointer, is a unique identifier of + // or path to (in scope of resource server and namespace) a single resource or resource collection + // 'path' notion remind that it can contain segments, each meaningfull to resource server + // i.e. each sub-path can correspond to different resources, and access policies can be assigned with any sub-path granularity + // but resource path will be considered as immutable, + // moving resources 'within' the path will break access control logic on both AuthZ server and resource server sides + // e.g: vms, vm1, queues, queue1 + // max length is 255 characters [a-Z0-9-_] + ResourcePath string `json:"rp,omitempty"` + + // Role - role available for the resource specified by resource id + Role string `json:"role,omitempty"` + + AllowPermissions []string `json:"allow,omitempty"` + DenyPermissions []string `json:"deny,omitempty"` +} + +type validatableClaims struct { + Claims + customValidator func(c *Claims) error +} + +func (v *validatableClaims) Validate() error { + if v.customValidator != nil { + return v.customValidator(&v.Claims) + } + return nil +} + +func makeCustomAudienceValidator(requireAudience bool, audienceMatchers []audienceMatcher) func(c *Claims) error { + return func(c *Claims) error { + if len(c.Audience) == 0 { + if requireAudience { + return fmt.Errorf("%w: %w", jwtgo.ErrTokenRequiredClaimMissing, &AudienceMissingError{c}) + } + return nil + } + + if len(audienceMatchers) == 0 { + return nil + } + for i := range audienceMatchers { + for j := range c.Audience { + if audienceMatchers[i](c.Audience[j]) { + return nil + } + } + } + return fmt.Errorf("%w: %w", jwtgo.ErrTokenInvalidAudience, &AudienceNotExpectedError{c}) + } +} diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go new file mode 100644 index 0000000..33c12a8 --- /dev/null +++ b/jwt/jwt_test.go @@ -0,0 +1,353 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package jwt_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/acronis/go-appkit/log" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/jwks" + "github.com/acronis/go-authkit/jwt" +) + +const testIss = "test-issuer" + +func TestJWTParser_Parse(t *testing.T) { + jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer.Close() + + issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) + defer issuerConfigServer.Close() + + logger := log.NewDisabledLogger() + + t.Run("ok", func(t *testing.T) { + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, + TOTPTime: time.Now().Unix(), + SubType: "task_manager", + } + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.NoError(t, err) + require.Equal(t, claims.Scope, parsedClaims.Scope) + require.Equal(t, claims.TOTPTime, parsedClaims.TOTPTime) + require.Equal(t, claims.SubType, parsedClaims.SubType) + }) + + t.Run("ok for empty kid", func(t *testing.T) { + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + } + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.NoError(t, err) + require.Equal(t, claims, parsedClaims) + }) + + t.Run("ok for trusted issuer url (glob pattern)", func(t *testing.T) { + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: issuerConfigServer.URL, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, + } + issURLs := []string{ + issuerConfigServer.URL, + "http://127.0.0.*", + "http://127.*", + } + for _, issURL := range issURLs { + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + require.NoError(t, parser.AddTrustedIssuerURL(issURL)) + parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.NoError(t, err) + require.Equal(t, claims, parsedClaims) + } + }) + + t.Run("ok for expected audience (glob pattern)", func(t *testing.T) { + for _, aud := range []string{"region1.cloud.com", "region2.cloud.com"} { + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Audience: []string{aud}, + Issuer: issuerConfigServer.URL, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, + } + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(http.DefaultClient, logger), logger, jwt.ParserOpts{ + ExpectedAudience: []string{"*.cloud.com"}, + }) + require.NoError(t, parser.AddTrustedIssuerURL(issuerConfigServer.URL)) + parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.NoError(t, err) + require.Equal(t, claims, parsedClaims) + } + }) + + t.Run("malformed jwt", func(t *testing.T) { + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + _, err := parser.Parse(context.Background(), "invalid-jwt") + require.ErrorIs(t, err, jwtgo.ErrTokenMalformed) + require.ErrorContains(t, err, "token contains an invalid number of segments") + }) + + t.Run("unsigned jwt", func(t *testing.T) { + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, + } + token := jwtgo.NewWithClaims(jwtgo.SigningMethodNone, claims) + tokenString, err := token.SignedString(jwtgo.UnsafeAllowNoneSignatureType) + require.NoError(t, err) + + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + _, err = parser.Parse(context.Background(), tokenString) + require.ErrorIs(t, err, jwtgo.NoneSignatureTypeDisallowedError) + }) + + t.Run("jwt issuer missing", func(t *testing.T) { + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{Audience: []string{"https://cloud.acronis.com"}}, + } + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) + var issMissingErr *jwt.IssuerMissingError + require.ErrorAs(t, err, &issMissingErr) + require.Equal(t, claims, issMissingErr.Claims) + }) + + t.Run("jwt has untrusted issuer", func(t *testing.T) { + const issuer = "untrusted-issuer" + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}} + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) + var issUntrustedErr *jwt.IssuerUntrustedError + require.ErrorAs(t, err, &issUntrustedErr) + require.Equal(t, claims, issUntrustedErr.Claims) + }) + + t.Run("jwt has untrusted issuer url", func(t *testing.T) { + const issuer = "https://3rd-party-idp.com" + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}} + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + require.NoError(t, parser.AddTrustedIssuerURL("https://*.acronis.com")) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) + var issUntrustedErr *jwt.IssuerUntrustedError + require.ErrorAs(t, err, &issUntrustedErr) + require.Equal(t, claims, issUntrustedErr.Claims) + }) + + t.Run("jwt has untrusted issuer url, callback adds it to trusted", func(t *testing.T) { + var callbackCallCount int + claims := &jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Audience: []string{issuerConfigServer.URL}, + Issuer: issuerConfigServer.URL, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }, + Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, + } + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(http.DefaultClient, logger), logger, jwt.ParserOpts{ + TrustedIssuerNotFoundFallback: func(ctx context.Context, p *jwt.Parser, iss string) (issURL string, issFound bool) { + callbackCallCount++ + addErr := p.AddTrustedIssuerURL(iss) + if addErr != nil { + return "", false + } + return iss, true + }, + }) + require.Equal(t, 0, callbackCallCount) + parsedClaims, pErr := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.NoError(t, pErr, "issuer must be considered as trusted and no error returned") + require.Equalf(t, 1, callbackCallCount, "Callback was not called by parser") + require.Equal(t, claims, parsedClaims) + parsedClaims, pErr = parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.NoError(t, pErr, "issuer must be considered as trusted and no error returned") + require.Equal(t, claims, parsedClaims) + require.Equalf(t, 1, callbackCallCount, "Callback should be called exactly once") + }) + + t.Run("jwt exp is missing", func(t *testing.T) { + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss}} + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) + require.ErrorIs(t, err, jwtgo.ErrTokenRequiredClaimMissing) + }) + + t.Run("jwt expired", func(t *testing.T) { + expiresAt := time.Now().Add(-time.Second) + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(expiresAt)}} + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) + require.ErrorIs(t, err, jwtgo.ErrTokenExpired) + }) + + t.Run("jwt not valid yet", func(t *testing.T) { + notBefore := time.Now().Add(time.Minute) + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + NotBefore: jwtgo.NewNumericDate(notBefore), + }} + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) + require.ErrorIs(t, err, jwtgo.ErrTokenNotValidYet) + }) + + t.Run("required jwt audience is missing", func(t *testing.T) { + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }} + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(http.DefaultClient, logger), logger, jwt.ParserOpts{ + RequireAudience: true, + }) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) + require.ErrorIs(t, err, jwtgo.ErrTokenRequiredClaimMissing) + var jwtErr *jwt.AudienceMissingError + require.ErrorAs(t, err, &jwtErr) + require.Equal(t, claims, jwtErr.Claims) + }) + + t.Run("jwt audience is not expected", func(t *testing.T) { + const audience = "not-expected-audience" + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{ + Audience: []string{audience}, + Issuer: testIss, + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), + }} + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(http.DefaultClient, logger), logger, jwt.ParserOpts{ + ExpectedAudience: []string{"expected-audience"}, + }) + parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) + _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) + require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) + require.ErrorIs(t, err, jwtgo.ErrTokenInvalidAudience) + var jwtErr *jwt.AudienceNotExpectedError + require.ErrorAs(t, err, &jwtErr) + require.Equal(t, claims, jwtErr.Claims) + }) + + t.Run("verification error", func(t *testing.T) { + jwksServer2 := httptest.NewServer(&idptest.JWKSHandler{}) + defer jwksServer2.Close() + + openIDCfgHandler2 := &idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer2.URL} + openIDCfgServer2 := httptest.NewServer(openIDCfgHandler2) + defer openIDCfgServer2.Close() + + const cacheUpdateMinInterval = time.Second + + claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} + tokenString, err := idptest.MakeTokenString(claims, "737c5114f09b5ed05276bd4b520245982f7fb29f", idptest.GetTestRSAPrivateKey()) + require.NoError(t, err) + jwksClient := jwks.NewCachingClientWithOpts(http.DefaultClient, logger, jwks.CachingClientOpts{CacheUpdateMinInterval: cacheUpdateMinInterval}) + parser := jwt.NewParser(jwksClient, logger) + parser.AddTrustedIssuer(testIss, openIDCfgServer2.URL) + + for i := 0; i < 2; i++ { + _, err = parser.Parse(context.Background(), tokenString) + require.ErrorIs(t, err, jwtgo.ErrTokenSignatureInvalid) + require.EqualValues(t, 1, openIDCfgHandler2.ServedCount()) + require.EqualValues(t, 1, openIDCfgHandler2.ServedCount()) + } + + time.Sleep(cacheUpdateMinInterval * 2) + + _, err = parser.Parse(context.Background(), tokenString) + require.ErrorIs(t, err, jwtgo.ErrTokenSignatureInvalid) + require.EqualValues(t, 2, openIDCfgHandler2.ServedCount()) + require.EqualValues(t, 2, openIDCfgHandler2.ServedCount()) + }) +} + +func TestParser_getURLForIssuer(t *testing.T) { + tests := []struct { + Name string + IssURLPattern string + TrustedIssURLs []string + NotTrustedIssURLs []string + }{ + { + Name: "wildcard in host", + IssURLPattern: "https://*.acronis.com/bc", + TrustedIssURLs: []string{ + "https://us-cloud.acronis.com/bc", + "https://eu2-cloud.acronis.com/bc", + }, + NotTrustedIssURLs: []string{ + "http://eu2-cloud.acronis.com/bc", + "https://eu2-cloud.acronis.com", + "https://eu2-cloud.acronis.com/bc/foobar", + "https://my-site.com/eu2-cloud.acronis.com/bc", + "https://my-site.com?foo=eu2-cloud.acronis.com/bc", + "https://eu2-cloud.acronis.com/bc?foo=bar", + }, + }, + { + Name: "no wildcard in path", + IssURLPattern: "https://eu3-cloud.acronis.com/bc", + TrustedIssURLs: []string{"https://eu3-cloud.acronis.com/bc"}, + NotTrustedIssURLs: []string{ + "https://eu1-cloud.acronis.com/bc", + "https://eu2-cloud.acronis.com/bc", + }, + }, + } + for i := range tests { + tt := tests[i] + t.Run(tt.Name, func(t *testing.T) { + logger := log.NewDisabledLogger() + parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + require.NoError(t, parser.AddTrustedIssuerURL(tt.IssURLPattern)) + for _, issURL := range tt.TrustedIssURLs { + u, ok := parser.GetURLForIssuer(issURL) + require.True(t, ok) + require.Equal(t, u, issURL) + } + for _, issURL := range tt.NotTrustedIssURLs { + _, ok := parser.GetURLForIssuer(issURL) + require.False(t, ok) + } + }) + } +} diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..d6e6f72 --- /dev/null +++ b/middleware.go @@ -0,0 +1,206 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "context" + "errors" + "net/http" + "strings" + + "github.com/acronis/go-appkit/httpserver/middleware" + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-appkit/restapi" + + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/jwt" +) + +// HeaderAuthorization contains the name of HTTP header with data that is used for authentication and authorization. +const HeaderAuthorization = "Authorization" + +// Authentication and authorization error codes. +// We are using "var" here because some services may want to use different error codes. +var ( + ErrCodeBearerTokenMissing = "bearerTokenMissing" + ErrCodeAuthenticationFailed = "authenticationFailed" + ErrCodeAuthorizationFailed = "authorizationFailed" +) + +// Authentication error messages. +// We are using "var" here because some services may want to use different error messages. +var ( + ErrMessageBearerTokenMissing = "Authorization bearer token is missing." + ErrMessageAuthenticationFailed = "Authentication is failed." + ErrMessageAuthorizationFailed = "Authorization is failed." +) + +type ctxKey int + +const ( + ctxKeyJWTClaims ctxKey = iota + ctxKeyBearerToken +) + +// JWTParser is an interface for parsing string representation of JWT. +type JWTParser interface { + Parse(ctx context.Context, token string) (*jwt.Claims, error) +} + +// CachingJWTParser does the same as JWTParser but stores parsed JWT claims in cache. +type CachingJWTParser interface { + JWTParser + InvalidateCache(ctx context.Context) +} + +// TokenIntrospector is an interface for introspecting tokens. +type TokenIntrospector interface { + IntrospectToken(ctx context.Context, token string) (idptoken.IntrospectionResult, error) +} + +type jwtAuthHandler struct { + next http.Handler + errorDomain string + jwtParser JWTParser + verifyAccess func(r *http.Request, claims *jwt.Claims) bool + tokenIntrospector TokenIntrospector +} + +type jwtAuthMiddlewareOpts struct { + verifyAccess func(r *http.Request, claims *jwt.Claims) bool + tokenIntrospector TokenIntrospector +} + +// JWTAuthMiddlewareOption is an option for JWTAuthMiddleware. +type JWTAuthMiddlewareOption func(options *jwtAuthMiddlewareOpts) + +// WithJWTAuthMiddlewareVerifyAccess is an option to set a function that verifies access for JWTAuthMiddleware. +func WithJWTAuthMiddlewareVerifyAccess(verifyAccess func(r *http.Request, claims *jwt.Claims) bool) JWTAuthMiddlewareOption { + return func(options *jwtAuthMiddlewareOpts) { + options.verifyAccess = verifyAccess + } +} + +// WithJWTAuthMiddlewareTokenIntrospector is an option to set a token introspector for JWTAuthMiddleware. +func WithJWTAuthMiddlewareTokenIntrospector(tokenIntrospector TokenIntrospector) JWTAuthMiddlewareOption { + return func(options *jwtAuthMiddlewareOpts) { + options.tokenIntrospector = tokenIntrospector + } +} + +// JWTAuthMiddleware is a middleware that does authentication +// by Access Token from the "Authorization" HTTP header of incoming request. +func JWTAuthMiddleware(errorDomain string, jwtParser JWTParser, opts ...JWTAuthMiddlewareOption) func(next http.Handler) http.Handler { + var options jwtAuthMiddlewareOpts + for _, opt := range opts { + opt(&options) + } + return func(next http.Handler) http.Handler { + return &jwtAuthHandler{next, errorDomain, jwtParser, options.verifyAccess, options.tokenIntrospector} + } +} + +func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + reqCtx := r.Context() + logger := middleware.GetLoggerFromContext(reqCtx) + + bearerToken := GetBearerTokenFromRequest(r) + if bearerToken == "" { + apiErr := restapi.NewError(h.errorDomain, ErrCodeBearerTokenMissing, ErrMessageBearerTokenMissing) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) + return + } + + var jwtClaims *jwt.Claims + if h.tokenIntrospector != nil { + if introspectionResult, err := h.tokenIntrospector.IntrospectToken(reqCtx, bearerToken); err != nil { + switch { + case errors.Is(err, idptoken.ErrTokenIntrospectionNotNeeded): + // Do nothing. Access Token already contains all necessary information for authN/authZ. + case errors.Is(err, idptoken.ErrTokenNotIntrospectable): + // Token is not introspectable by some reason. + // In this case, we will parse it as JWT and use it for authZ. + if logger != nil { + logger.Warn("token is not introspectable, it will be used for authentication and authorization as is", + log.Error(err)) + } + default: + apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) + return + } + } else { + if !introspectionResult.Active { + apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) + return + } + jwtClaims = &introspectionResult.Claims + } + } + + if jwtClaims == nil { + var err error + if jwtClaims, err = h.jwtParser.Parse(reqCtx, bearerToken); err != nil { + if logger != nil { + logger.Error("authentication failed", log.Error(err)) + } + apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) + return + } + } + + if h.verifyAccess != nil { + if !h.verifyAccess(r, jwtClaims) { + apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthorizationFailed, ErrMessageAuthorizationFailed) + restapi.RespondError(rw, http.StatusForbidden, apiErr, logger) + return + } + } + + reqCtx = NewContextWithBearerToken(reqCtx, bearerToken) + reqCtx = NewContextWithJWTClaims(reqCtx, jwtClaims) + h.next.ServeHTTP(rw, r.WithContext(reqCtx)) +} + +// GetBearerTokenFromRequest extracts jwt token from request headers. +func GetBearerTokenFromRequest(r *http.Request) string { + authHeader := strings.TrimSpace(r.Header.Get(HeaderAuthorization)) + if strings.HasPrefix(authHeader, "Bearer ") || strings.HasPrefix(authHeader, "bearer ") { + return authHeader[7:] + } + return "" +} + +// NewContextWithJWTClaims creates a new context with JWT claims. +func NewContextWithJWTClaims(ctx context.Context, jwtClaims *jwt.Claims) context.Context { + return context.WithValue(ctx, ctxKeyJWTClaims, jwtClaims) +} + +// GetJWTClaimsFromContext extracts JWT claims from the context. +func GetJWTClaimsFromContext(ctx context.Context) *jwt.Claims { + value := ctx.Value(ctxKeyJWTClaims) + if value == nil { + return nil + } + return value.(*jwt.Claims) +} + +// NewContextWithBearerToken creates a new context with token. +func NewContextWithBearerToken(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, ctxKeyBearerToken, token) +} + +// GetBearerTokenFromContext extracts token from the context. +func GetBearerTokenFromContext(ctx context.Context) string { + value := ctx.Value(ctxKeyBearerToken) + if value == nil { + return "" + } + return value.(string) +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..45c58c5 --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,242 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package auth + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/acronis/go-appkit/testutil" + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + + "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/jwt" +) + +type mockJWTAuthMiddlewareNextHandler struct { + called int + jwtClaims *jwt.Claims +} + +func (h *mockJWTAuthMiddlewareNextHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { + h.called++ + h.jwtClaims = GetJWTClaimsFromContext(r.Context()) +} + +type mockJWTParser struct { + parseCalled int + claimsToReturn *jwt.Claims + errToReturn error + passedToken string +} + +func (p *mockJWTParser) Parse(_ context.Context, token string) (*jwt.Claims, error) { + p.parseCalled++ + p.passedToken = token + return p.claimsToReturn, p.errToReturn +} + +type mockTokenIntrospector struct { + introspectCalled int + introspectedToken string + resultToReturn idptoken.IntrospectionResult + errToReturn error +} + +func (i *mockTokenIntrospector) IntrospectToken(_ context.Context, token string) (idptoken.IntrospectionResult, error) { + i.introspectCalled++ + i.introspectedToken = token + return i.resultToReturn, i.errToReturn +} + +func TestJWTAuthMiddleware(t *testing.T) { + const errDomain = "TestDomain" + + t.Run("bearer token is missing", func(t *testing.T) { + for _, headerVal := range []string{"", "foobar", "Bearer", "Bearer "} { + parser := &mockJWTParser{} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + if headerVal != "" { + req.Header.Set(HeaderAuthorization, headerVal) + } + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) + + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeBearerTokenMissing) + require.Equal(t, 0, parser.parseCalled) + require.Equal(t, 0, next.called) + require.Nil(t, next.jwtClaims) + } + }) + + t.Run("authentication failed", func(t *testing.T) { + parser := &mockJWTParser{errToReturn: errors.New("malformed JWT")} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer foobar") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) + + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 0, next.called) + require.Nil(t, next.jwtClaims) + }) + + t.Run("ok", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{claimsToReturn: &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 1, next.called) + require.NotNil(t, next.jwtClaims) + require.Equal(t, issuer, next.jwtClaims.Issuer) + }) + + t.Run("introspection failed", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{claimsToReturn: &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + introspector := &mockTokenIntrospector{errToReturn: errors.New("introspection failed")} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) + require.Equal(t, 1, introspector.introspectCalled) + require.Equal(t, 0, parser.parseCalled) + require.Equal(t, 0, next.called) + }) + + t.Run("introspection is not needed", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{claimsToReturn: &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + introspector := &mockTokenIntrospector{errToReturn: idptoken.ErrTokenIntrospectionNotNeeded} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, 1, introspector.introspectCalled) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 1, next.called) + require.NotNil(t, next.jwtClaims) + require.Equal(t, issuer, next.jwtClaims.Issuer) + }) + + t.Run("ok, token is not introspectable", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{claimsToReturn: &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + introspector := &mockTokenIntrospector{errToReturn: idptoken.ErrTokenNotIntrospectable} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, 1, introspector.introspectCalled) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 1, next.called) + require.NotNil(t, next.jwtClaims) + require.Equal(t, issuer, next.jwtClaims.Issuer) + }) + + t.Run("authentication failed, token is introspected but inactive", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{} + introspector := &mockTokenIntrospector{resultToReturn: idptoken.IntrospectionResult{Active: false}} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) + require.Equal(t, 1, introspector.introspectCalled) + require.Equal(t, 0, parser.parseCalled) + require.Equal(t, 0, next.called) + }) + + t.Run("ok, token is introspected and active", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{} + introspector := &mockTokenIntrospector{resultToReturn: idptoken.IntrospectionResult{Active: true, Claims: jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}}} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, 1, introspector.introspectCalled) + require.Equal(t, 0, parser.parseCalled) + require.Equal(t, 1, next.called) + require.NotNil(t, next.jwtClaims) + require.Equal(t, issuer, next.jwtClaims.Issuer) + }) +} + +func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) { + const errDomain = "TestDomain" + + t.Run("authorization failed", func(t *testing.T) { + parser := &mockJWTParser{claimsToReturn: &jwt.Claims{}} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + verifyAccess := NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"}) + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).ServeHTTP(resp, req) + + testutil.RequireErrorInRecorder(t, resp, http.StatusForbidden, errDomain, ErrCodeAuthorizationFailed) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 0, next.called) + require.Nil(t, next.jwtClaims) + }) + + t.Run("ok", func(t *testing.T) { + scope := []jwt.AccessPolicy{{ResourceNamespace: "my-service", Role: "admin"}} + parser := &mockJWTParser{claimsToReturn: &jwt.Claims{Scope: scope}} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + resp := httptest.NewRecorder() + + verifyAccess := NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"}) + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 1, next.called) + require.NotNil(t, next.jwtClaims) + require.EqualValues(t, scope, next.jwtClaims.Scope) + }) +} diff --git a/min-coverage.txt b/min-coverage.txt new file mode 100644 index 0000000..d15a2cc --- /dev/null +++ b/min-coverage.txt @@ -0,0 +1 @@ +80