diff --git a/README.md b/README.md index 31bccbb..0176f5d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,11 @@ # Toolkit for authentication and authorization in Go services +## Installation + +``` +go get -u github.com/acronis/go-authkit +``` + ## Features - Authenticate HTTP requests with JWT tokens via middleware that can be configured via YAML/JSON file or environment variables. diff --git a/auth.go b/auth.go index ea57dd4..41951fa 100644 --- a/auth.go +++ b/auth.go @@ -12,23 +12,17 @@ import ( "fmt" "net/http" "os" - "time" "github.com/acronis/go-appkit/log" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/internal/idputil" "github.com/acronis/go-authkit/jwks" "github.com/acronis/go-authkit/jwt" ) -// Default values. -const ( - DefaultHTTPClientRequestTimeout = time.Second * 30 - DefaultGRPCClientRequestTimeout = time.Second * 30 -) - // NewJWTParser creates a new JWTParser with the given configuration. // If cfg.JWT.ClaimsCache.Enabled is true, then jwt.CachingParser created, otherwise - jwt.Parser. func NewJWTParser(cfg *Config, opts ...JWTParserOption) (JWTParser, error) { @@ -36,25 +30,23 @@ func NewJWTParser(cfg *Config, opts ...JWTParserOption) (JWTParser, error) { for _, opt := range opts { opt(&options) } - logger := options.logger - if logger == nil { - logger = log.NewDisabledLogger() - } + + logger := idputil.PrepareLogger(options.logger) // Make caching JWKS client. jwksCacheUpdateMinInterval := cfg.JWKS.Cache.UpdateMinInterval if jwksCacheUpdateMinInterval == 0 { jwksCacheUpdateMinInterval = jwks.DefaultCacheUpdateMinInterval } - httpClientRequestTimeout := cfg.HTTPClient.RequestTimeout - if httpClientRequestTimeout == 0 { - httpClientRequestTimeout = DefaultHTTPClientRequestTimeout - } jwksClientOpts := jwks.CachingClientOpts{ - ClientOpts: jwks.ClientOpts{PrometheusLibInstanceLabel: options.prometheusLibInstanceLabel}, + ClientOpts: jwks.ClientOpts{ + Logger: logger, + HTTPClient: idputil.MakeDefaultHTTPClient(cfg.HTTPClient.RequestTimeout, logger), + PrometheusLibInstanceLabel: options.prometheusLibInstanceLabel, + }, CacheUpdateMinInterval: jwksCacheUpdateMinInterval, } - jwksClient := jwks.NewCachingClientWithOpts(&http.Client{Timeout: httpClientRequestTimeout}, logger, jwksClientOpts) + jwksClient := jwks.NewCachingClientWithOpts(jwksClientOpts) // Make JWT parser. @@ -134,10 +126,8 @@ func NewTokenIntrospector( for _, opt := range opts { opt(&options) } - logger := options.logger - if logger == nil { - logger = log.NewDisabledLogger() - } + + logger := idputil.PrepareLogger(options.logger) if len(cfg.JWT.TrustedIssuers) == 0 && len(cfg.JWT.TrustedIssuerURLs) == 0 { logger.Warn("list of trusted issuers is empty, jwt introspection may not work properly") @@ -156,15 +146,10 @@ func NewTokenIntrospector( } } - httpClientRequestTimeout := cfg.HTTPClient.RequestTimeout - if httpClientRequestTimeout == 0 { - httpClientRequestTimeout = DefaultHTTPClientRequestTimeout - } - introspectorOpts := idptoken.IntrospectorOpts{ StaticHTTPEndpoint: cfg.Introspection.Endpoint, GRPCClient: grpcClient, - HTTPClient: &http.Client{Timeout: httpClientRequestTimeout}, + HTTPClient: idputil.MakeDefaultHTTPClient(cfg.HTTPClient.RequestTimeout, logger), AccessTokenScope: cfg.Introspection.AccessTokenScope, Logger: logger, ScopeFilter: scopeFilter, diff --git a/config.go b/config.go index b18ee1b..d34919b 100644 --- a/config.go +++ b/config.go @@ -14,6 +14,7 @@ import ( "github.com/acronis/go-appkit/config" "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/internal/idputil" "github.com/acronis/go-authkit/jwks" "github.com/acronis/go-authkit/jwt" ) @@ -142,8 +143,8 @@ func (c *Config) KeyPrefix() string { // SetProviderDefaults sets default configuration values for auth in config.DataProvider. func (c *Config) SetProviderDefaults(dp config.DataProvider) { - dp.SetDefault(cfgKeyHTTPClientRequestTimeout, DefaultHTTPClientRequestTimeout.String()) - dp.SetDefault(cfgKeyGRPCClientRequestTimeout, DefaultGRPCClientRequestTimeout.String()) + dp.SetDefault(cfgKeyHTTPClientRequestTimeout, idputil.DefaultHTTPRequestTimeout.String()) + dp.SetDefault(cfgKeyGRPCClientRequestTimeout, idptoken.DefaultGRPCClientRequestTimeout.String()) dp.SetDefault(cfgKeyJWTClaimsCacheMaxEntries, jwt.DefaultClaimsCacheMaxEntries) dp.SetDefault(cfgKeyJWKSCacheUpdateMinInterval, jwks.DefaultCacheUpdateMinInterval.String()) dp.SetDefault(cfgKeyIntrospectionClaimsCacheMaxEntries, idptoken.DefaultIntrospectionClaimsCacheMaxEntries) diff --git a/examples/idp-test-server/main.go b/examples/idp-test-server/main.go index c47faff..aea6e96 100644 --- a/examples/idp-test-server/main.go +++ b/examples/idp-test-server/main.go @@ -33,7 +33,8 @@ func runApp() error { logger, loggerClose := log.NewLogger(&log.Config{Output: log.OutputStdout, Level: log.LevelInfo, Format: log.FormatJSON}) defer loggerClose() - jwtParser := jwt.NewParser(jwks.NewCachingClient(&http.Client{Timeout: time.Second * 30}, logger), logger) + jwksClientOpts := jwks.CachingClientOpts{ClientOpts: jwks.ClientOpts{Logger: logger}} + jwtParser := jwt.NewParser(jwks.NewCachingClientWithOpts(jwksClientOpts), logger) _ = jwtParser.AddTrustedIssuerURL("http://" + idpAddr) idpSrv := idptest.NewHTTPServer( idptest.WithHTTPAddress(idpAddr), diff --git a/idptest/jwt.go b/idptest/jwt.go index 172dee5..51660b1 100644 --- a/idptest/jwt.go +++ b/idptest/jwt.go @@ -21,6 +21,16 @@ func SignToken(token *jwtgo.Token, rsaPrivateKey interface{}) (string, error) { return token.SignedString(rsaPrivateKey) } +// MustSignToken signs token with key. +// It panics if error occurs. +func MustSignToken(token *jwtgo.Token, rsaPrivateKey interface{}) string { + s, err := SignToken(token, rsaPrivateKey) + if err != nil { + panic(err) + } + return s +} + // MakeTokenStringWithHeader create test signed token with claims and headers. func MakeTokenStringWithHeader( claims jwtgo.Claims, kid string, rsaPrivateKey interface{}, header map[string]interface{}, diff --git a/idptest/jwt_test.go b/idptest/jwt_test.go index a8056e9..b3caa91 100644 --- a/idptest/jwt_test.go +++ b/idptest/jwt_test.go @@ -8,7 +8,6 @@ package idptest import ( "context" - "net/http" "net/http/httptest" "testing" "time" @@ -42,7 +41,7 @@ func TestMakeTokenStringWithHeader(t *testing.T) { }, } - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) parsedClaims, err := parser.Parse(context.Background(), MustMakeTokenStringSignedWithTestKey(jwtClaims)) require.NoError(t, err) diff --git a/idptoken/caching_introspector_test.go b/idptoken/caching_introspector_test.go index b6575f7..a4d61df 100644 --- a/idptoken/caching_introspector_test.go +++ b/idptoken/caching_introspector_test.go @@ -8,7 +8,6 @@ package idptoken_test import ( "context" - "net/http" "net/url" gotesting "testing" "time" @@ -36,7 +35,7 @@ func TestCachingIntrospector_IntrospectToken(t *gotesting.T) { tokenProvider := idptest.NewSimpleTokenProvider(accessToken) logger := log.NewDisabledLogger() - jwtParser := jwt.NewParser(jwks.NewClient(http.DefaultClient, logger), logger) + jwtParser := jwt.NewParser(jwks.NewClient(), logger) require.NoError(t, jwtParser.AddTrustedIssuerURL(idpSrv.URL())) serverIntrospector.JWTParser = jwtParser diff --git a/idptoken/grpc_client.go b/idptoken/grpc_client.go index 7b1f743..b787d8e 100644 --- a/idptoken/grpc_client.go +++ b/idptoken/grpc_client.go @@ -22,10 +22,13 @@ import ( grpcstatus "google.golang.org/grpc/status" "github.com/acronis/go-authkit/idptoken/pb" + "github.com/acronis/go-authkit/internal/idputil" "github.com/acronis/go-authkit/internal/metrics" "github.com/acronis/go-authkit/jwt" ) +const DefaultGRPCClientRequestTimeout = time.Second * 30 + // GRPCClientOpts contains options for the GRPCClient. type GRPCClientOpts struct { // Logger is a logger for the client. @@ -61,11 +64,9 @@ func NewGRPCClient( func NewGRPCClientWithOpts( target string, transportCreds credentials.TransportCredentials, opts GRPCClientOpts, ) (*GRPCClient, error) { - if opts.Logger == nil { - opts.Logger = log.NewDisabledLogger() - } + opts.Logger = idputil.PrepareLogger(opts.Logger) if opts.RequestTimeout == 0 { - opts.RequestTimeout = time.Second * 30 + opts.RequestTimeout = DefaultGRPCClientRequestTimeout } conn, err := grpc.NewClient(target, grpc.WithTransportCredentials(transportCreds), diff --git a/idptoken/introspector.go b/idptoken/introspector.go index c996631..e8a18f7 100644 --- a/idptoken/introspector.go +++ b/idptoken/introspector.go @@ -27,8 +27,6 @@ import ( "github.com/acronis/go-authkit/jwt" ) -const DefaultRequestTimeout = 30 * time.Second - const JWTTypeAccessToken = "at+jwt" const TokenTypeBearer = "bearer" @@ -137,11 +135,9 @@ func NewIntrospector(tokenProvider IntrospectionTokenProvider) *Introspector { // NewIntrospectorWithOpts creates a new Introspector with the given token provider and options. // See IntrospectorOpts for more details. func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opts IntrospectorOpts) *Introspector { + opts.Logger = idputil.PrepareLogger(opts.Logger) if opts.HTTPClient == nil { - opts.HTTPClient = &http.Client{Timeout: DefaultRequestTimeout} - } - if opts.Logger == nil { - opts.Logger = log.NewDisabledLogger() + opts.HTTPClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, opts.Logger) } values := url.Values{} diff --git a/idptoken/introspector_test.go b/idptoken/introspector_test.go index 813c8ea..74d7f77 100644 --- a/idptoken/introspector_test.go +++ b/idptoken/introspector_test.go @@ -8,7 +8,6 @@ package idptoken_test import ( "context" - "net/http" "net/url" gotesting "testing" "time" @@ -42,8 +41,7 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { const accessToken = "access-token-with-introspection-permission" tokenProvider := idptest.NewSimpleTokenProvider(accessToken) - logger := log.NewDisabledLogger() - jwtParser := jwt.NewParser(jwks.NewClient(http.DefaultClient, logger), logger) + jwtParser := jwt.NewParser(jwks.NewClient(), log.NewDisabledLogger()) require.NoError(t, jwtParser.AddTrustedIssuerURL(httpIDPSrv.URL())) httpServerIntrospector.JWTParser = jwtParser grpcServerIntrospector.JWTParser = jwtParser diff --git a/idptoken/provider.go b/idptoken/provider.go index 80f44a5..0b74818 100644 --- a/idptoken/provider.go +++ b/idptoken/provider.go @@ -161,29 +161,42 @@ type MultiSourceProvider struct { nextRefresh time.Time } -// NewMultiSourceProviderWithOpts returns a new instance of MultiSourceProvider with custom settings -func NewMultiSourceProviderWithOpts( - httpClient *http.Client, opts ProviderOpts, sources ...Source, -) *MultiSourceProvider { - p := MultiSourceProvider{} +// NewMultiSourceProvider returns a new instance of MultiSourceProvider with default settings +func NewMultiSourceProvider(sources []Source) *MultiSourceProvider { + return NewMultiSourceProviderWithOpts(sources, ProviderOpts{}) +} - if opts.Logger == nil { - opts.Logger = log.NewDisabledLogger() +// NewMultiSourceProviderWithOpts returns a new instance of MultiSourceProvider with custom settings +func NewMultiSourceProviderWithOpts(sources []Source, opts ProviderOpts) *MultiSourceProvider { + p := MultiSourceProvider{ + rescheduleSignal: make(chan struct{}, 1), + nextRefresh: zeroTime, + minRefreshPeriod: opts.MinRefreshPeriod, + logger: idputil.PrepareLogger(opts.Logger), + tokenIssuers: make(map[string]*oauth2Issuer), + promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "token_provider"), + customHeaders: opts.CustomHeaders, + cache: opts.CustomCacheInstance, + httpClient: opts.HTTPClient, + } + + if p.minRefreshPeriod == 0 { + p.minRefreshPeriod = defaultMinRefreshPeriod + } + if p.cache == nil { + p.cache = NewInMemoryTokenCache() + } + if p.httpClient == nil { + p.httpClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, p.logger) } - if opts.MinRefreshPeriod == 0 { - opts.MinRefreshPeriod = defaultMinRefreshPeriod + for _, source := range sources { + p.RegisterSource(source) } - p.init(httpClient, opts, sources...) return &p } -// NewMultiSourceProvider returns a new instance of MultiSourceProvider with default settings -func NewMultiSourceProvider(httpClient *http.Client) *MultiSourceProvider { - return NewMultiSourceProviderWithOpts(httpClient, ProviderOpts{}) -} - // RegisterSource allows registering a new Source into MultiSourceProvider func (p *MultiSourceProvider) RegisterSource(source Source) { key := keyForIssuer(source.ClientID, source.URL) @@ -240,25 +253,22 @@ func (p *MultiSourceProvider) issueToken( } _, errEns, _ := p.sfGroup.Do(keyForIssuer(clientID, sourceURL), func() (interface{}, error) { - return nil, issuer.EnsureIssuerURL(ctx, headers) + return nil, issuer.ensureTokenURL(ctx, headers) }) - if errEns != nil { p.logger.Error(fmt.Sprintf("(%s, %s): ensure issuer URL", sourceURL, clientID), log.Error(errEns)) return TokenData{}, errEns } sortedScope := uniqAndSort(scope) - key := keyForCache(clientID, issuer.loadIssuerURL(), sortedScope) - + key := keyForCache(clientID, issuer.loadTokenURL(), sortedScope) token, err, _ := p.sfGroup.Do(key, func() (interface{}, error) { - result, issErr := issuer.IssueToken(ctx, headers, sortedScope) + result, issErr := issuer.issueToken(ctx, headers, sortedScope) p.cacheToken(result, sourceURL) return result, issErr }) - if err != nil { - p.logger.Error(fmt.Sprintf("(%s, %s): issuing token", issuer.loadIssuerURL(), clientID), log.Error(err)) + p.logger.Error(fmt.Sprintf("(%s, %s): issuing token", issuer.loadTokenURL(), clientID), log.Error(err)) return TokenData{}, err } @@ -324,11 +334,11 @@ func (p *MultiSourceProvider) getCachedOrInvalidate(clientID, sourceURL string, if !found { return TokenData{}, fmt.Errorf("(%s, %s): not registered", sourceURL, clientID) } - if issuer.loadIssuerURL() == "" { + if issuer.loadTokenURL() == "" { return TokenData{}, fmt.Errorf("(%s, %s): issuer URL not acquired", sourceURL, clientID) } - key := keyForCache(clientID, issuer.loadIssuerURL(), uniqAndSort(scope)) + key := keyForCache(clientID, issuer.loadTokenURL(), uniqAndSort(scope)) details := p.cache.Get(key) if details == nil { return TokenData{}, errors.New("token not found in cache") @@ -456,28 +466,6 @@ func (p *MultiSourceProvider) refreshLoop(ctx context.Context) { } } -func (p *MultiSourceProvider) init(httpClient *http.Client, opts ProviderOpts, sources ...Source) { - if httpClient == nil { - panic("httpClient is mandatory") - } - p.cache = opts.CustomCacheInstance - if p.cache == nil { - p.cache = NewInMemoryTokenCache() - } - p.rescheduleSignal = make(chan struct{}, 1) - p.nextRefresh = zeroTime - p.minRefreshPeriod = opts.MinRefreshPeriod - p.logger = opts.Logger - p.httpClient = httpClient - p.tokenIssuers = make(map[string]*oauth2Issuer) - p.promMetrics = metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "token_provider") - p.customHeaders = opts.CustomHeaders - - for _, source := range sources { - p.RegisterSource(source) - } -} - // Provider is a caching token provider for a single credentials set type Provider struct { provider *MultiSourceProvider @@ -485,15 +473,15 @@ type Provider struct { } // NewProvider returns a new instance of Provider -func NewProvider(httpClient *http.Client, source Source) *Provider { - return NewProviderWithOpts(httpClient, ProviderOpts{}, source) +func NewProvider(source Source) *Provider { + return NewProviderWithOpts(source, ProviderOpts{}) } // NewProviderWithOpts returns a new instance of Provider with custom options -func NewProviderWithOpts(httpClient *http.Client, opts ProviderOpts, source Source) *Provider { +func NewProviderWithOpts(source Source, opts ProviderOpts) *Provider { mp := Provider{ source: source, - provider: NewMultiSourceProviderWithOpts(httpClient, opts, source), + provider: NewMultiSourceProviderWithOpts([]Source{source}, opts), } return &mp } @@ -527,7 +515,7 @@ type oauth2Issuer struct { clientSecret string httpClient *http.Client logger log.FieldLogger - issuerURL atomic.Value + tokenURL atomic.Value promMetrics *metrics.PrometheusMetrics } @@ -542,15 +530,15 @@ func (p *MultiSourceProvider) newOAuth2Issuer(baseURL, clientID, clientSecret st } } -func (ti *oauth2Issuer) loadIssuerURL() string { - if v := ti.issuerURL.Load(); v != nil { +func (ti *oauth2Issuer) loadTokenURL() string { + if v := ti.tokenURL.Load(); v != nil { return v.(string) } return "" } -func (ti *oauth2Issuer) EnsureIssuerURL(ctx context.Context, customHeaders map[string]string) error { - if ti.loadIssuerURL() != "" { +func (ti *oauth2Issuer) ensureTokenURL(ctx context.Context, customHeaders map[string]string) error { + if ti.loadTokenURL() != "" { return nil } @@ -565,16 +553,16 @@ func (ti *oauth2Issuer) EnsureIssuerURL(ctx context.Context, customHeaders map[s return fmt.Errorf("(%s, %s): issuer have returned a non-valid URL %q: %w", ti.baseURL, ti.clientID, openIDCfg.TokenURL, err) } - ti.issuerURL.Store(openIDCfg.TokenURL) + ti.tokenURL.Store(openIDCfg.TokenURL) return nil } -func (ti *oauth2Issuer) IssueToken( +func (ti *oauth2Issuer) issueToken( ctx context.Context, customHeaders map[string]string, scope []string, ) (TokenData, error) { - issuerURL := ti.loadIssuerURL() - if issuerURL == "" { - panic("must first ensure issuerURL") + tokenURL := ti.loadTokenURL() + if tokenURL == "" { + return TokenData{}, fmt.Errorf("token URL is empty") } values := url.Values{} values.Add("grant_type", "client_credentials") @@ -582,7 +570,7 @@ func (ti *oauth2Issuer) IssueToken( if scopeStr != "" { values.Add("scope", scopeStr) } - req, reqErr := http.NewRequest(http.MethodPost, issuerURL, strings.NewReader(values.Encode())) + req, reqErr := http.NewRequest(http.MethodPost, tokenURL, strings.NewReader(values.Encode())) if reqErr != nil { return TokenData{}, reqErr } @@ -598,13 +586,13 @@ func (ti *oauth2Issuer) IssueToken( elapsed := time.Since(start) if err != nil { - ti.promMetrics.ObserveHTTPClientRequest(http.MethodPost, issuerURL, 0, elapsed, metrics.HTTPRequestErrorDo) + ti.promMetrics.ObserveHTTPClientRequest(http.MethodPost, tokenURL, 0, elapsed, metrics.HTTPRequestErrorDo) return TokenData{}, fmt.Errorf("do http request: %w", err) } defer func() { if err := resp.Body.Close(); err != nil { ti.logger.Error( - fmt.Sprintf("(%s, %s): closing body", ti.loadIssuerURL(), ti.clientID), log.Error(err), + fmt.Sprintf("(%s, %s): closing body", ti.loadTokenURL(), ti.clientID), log.Error(err), ) } }() @@ -612,26 +600,26 @@ func (ti *oauth2Issuer) IssueToken( tokenResponse := tokenResponseBody{} if err = json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { ti.promMetrics.ObserveHTTPClientRequest( - http.MethodPost, issuerURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorDecodeBody) + http.MethodPost, tokenURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorDecodeBody) return TokenData{}, fmt.Errorf( - "(%s, %s): read and unmarshal IDP response: %w", ti.loadIssuerURL(), ti.clientID, err, + "(%s, %s): read and unmarshal IDP response: %w", ti.loadTokenURL(), ti.clientID, err, ) } if resp.StatusCode != http.StatusOK { ti.promMetrics.ObserveHTTPClientRequest( - http.MethodPost, issuerURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorUnexpectedStatusCode) - return TokenData{}, &UnexpectedIDPResponseError{HTTPCode: resp.StatusCode, IssueURL: ti.loadIssuerURL()} + http.MethodPost, tokenURL, resp.StatusCode, elapsed, metrics.HTTPRequestErrorUnexpectedStatusCode) + return TokenData{}, &UnexpectedIDPResponseError{HTTPCode: resp.StatusCode, IssueURL: ti.loadTokenURL()} } - ti.promMetrics.ObserveHTTPClientRequest(http.MethodPost, issuerURL, resp.StatusCode, elapsed, "") + ti.promMetrics.ObserveHTTPClientRequest(http.MethodPost, tokenURL, resp.StatusCode, elapsed, "") expires := time.Now().Add(time.Second * time.Duration(tokenResponse.ExpiresIn)) - ti.logger.Infof("(%s, %s): issued token, expires on %s", ti.loadIssuerURL(), ti.clientID, expires.UTC()) + ti.logger.Infof("(%s, %s): issued token, expires on %s", ti.loadTokenURL(), ti.clientID, expires.UTC()) return TokenData{ Data: tokenResponse.AccessToken, Scope: scope, Expires: expires, - issueURL: ti.loadIssuerURL(), + issueURL: ti.loadTokenURL(), ClientID: ti.clientID, }, nil } @@ -641,6 +629,9 @@ type ProviderOpts struct { // Logger is a logger for MultiSourceProvider. Logger log.FieldLogger + // HTTPClient is an HTTP client for MultiSourceProvider. + HTTPClient *http.Client + // MinRefreshPeriod is a minimal possible refresh interval for MultiSourceProvider's token cache. MinRefreshPeriod time.Duration diff --git a/idptoken/provider_test.go b/idptoken/provider_test.go index c0a3363..38ca903 100644 --- a/idptoken/provider_test.go +++ b/idptoken/provider_test.go @@ -104,7 +104,7 @@ func TestProviderWithCache(t *testing.T) { MinRefreshPeriod: 1 * time.Second, CustomHeaders: map[string]string{"User-Agent": expectedUserAgent}, } - provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + provider := idptoken.NewMultiSourceProviderWithOpts(credentials, opts) go provider.RefreshTokensPeriodically(context.Background()) _, tokenErr := provider.GetTokenWithHeaders( context.Background(), testClientID, server.URL(), @@ -133,7 +133,7 @@ func TestProviderWithCache(t *testing.T) { Logger: logger, MinRefreshPeriod: 1 * time.Second, } - provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + provider := idptoken.NewMultiSourceProviderWithOpts(credentials, opts) go provider.RefreshTokensPeriodically(context.Background()) cachedToken, tokenErr := provider.GetToken( context.Background(), testClientID, server.URL(), "tenants:read", @@ -172,7 +172,7 @@ func TestProviderWithCache(t *testing.T) { Logger: logger, MinRefreshPeriod: 1 * time.Second, } - provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + provider := idptoken.NewMultiSourceProviderWithOpts(credentials, opts) go provider.RefreshTokensPeriodically(context.Background()) tokenOld, tokenErr := provider.GetToken( @@ -205,7 +205,7 @@ func TestProviderWithCache(t *testing.T) { Logger: logger, MinRefreshPeriod: 10 * time.Second, } - provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + provider := idptoken.NewMultiSourceProviderWithOpts(credentials, opts) go provider.RefreshTokensPeriodically(context.Background()) tokenOld, tokenErr := provider.GetToken( @@ -242,7 +242,7 @@ func TestProviderWithCache(t *testing.T) { Logger: logger, MinRefreshPeriod: 1 * time.Second, } - provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + provider := idptoken.NewMultiSourceProviderWithOpts(credentials, opts) go provider.RefreshTokensPeriodically(context.Background()) _, tokenErr := provider.GetToken( context.Background(), testClientID, server.URL(), "tenants:read", @@ -277,7 +277,7 @@ func TestProviderWithCache(t *testing.T) { Logger: logger, MinRefreshPeriod: 1 * time.Second, } - provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + provider := idptoken.NewMultiSourceProviderWithOpts(credentials, opts) go provider.RefreshTokensPeriodically(context.Background()) _, tokenErr := provider.GetToken(context.Background(), testClientID, server.URL(), "tenants:read") require.NoError(t, tokenErr) @@ -318,7 +318,7 @@ func TestProviderWithCache(t *testing.T) { Logger: logger, MinRefreshPeriod: 1 * time.Second, } - provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials...) + provider := idptoken.NewMultiSourceProviderWithOpts(credentials, opts) go provider.RefreshTokensPeriodically(context.Background()) _, tokenErr := provider.GetToken( context.Background(), testClientID, server.URL(), "tenants:read", @@ -357,7 +357,7 @@ func TestProviderWithCache(t *testing.T) { Logger: logger, MinRefreshPeriod: 1 * time.Second, } - provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, opts, credentials[0]) + provider := idptoken.NewMultiSourceProviderWithOpts(credentials[:1], opts) go provider.RefreshTokensPeriodically(context.Background()) provider.RegisterSource(credentials[1]) _, tokenErr := provider.GetToken( @@ -380,7 +380,7 @@ func TestProviderWithCache(t *testing.T) { Logger: logger, MinRefreshPeriod: 1 * time.Second, } - provider := idptoken.NewProviderWithOpts(httpClient, opts, credentials) + provider := idptoken.NewProviderWithOpts(credentials, opts) go provider.RefreshTokensPeriodically(context.Background()) _, tokenErr := provider.GetToken(context.Background(), "tenants:read") require.NoError(t, tokenErr) @@ -396,7 +396,7 @@ func TestProviderWithCache(t *testing.T) { credentials := idptoken.Source{ ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXS", URL: server.URL(), } - provider := idptoken.NewMultiSourceProvider(httpClient) + provider := idptoken.NewMultiSourceProviderWithOpts(nil, idptoken.ProviderOpts{HTTPClient: httpClient}) go provider.RefreshTokensPeriodically(context.Background()) provider.RegisterSource(credentials) _, tokenErr := provider.GetToken( @@ -416,7 +416,8 @@ func TestProviderWithCache(t *testing.T) { ClientID: testClientID, ClientSecret: "DAGztV5L2hMZyECzer6SXS", URL: server.URL(), } tokenCache := idptoken.NewInMemoryTokenCache() - provider := idptoken.NewMultiSourceProviderWithOpts(httpClient, idptoken.ProviderOpts{CustomCacheInstance: tokenCache}) + provider := idptoken.NewMultiSourceProviderWithOpts(nil, idptoken.ProviderOpts{ + CustomCacheInstance: tokenCache, HTTPClient: httpClient}) go provider.RefreshTokensPeriodically(context.Background()) provider.RegisterSource(credentials) credentials.ClientSecret = "newsecret" diff --git a/internal/idputil/idp_util.go b/internal/idputil/idp_util.go new file mode 100644 index 0000000..9774fa4 --- /dev/null +++ b/internal/idputil/idp_util.go @@ -0,0 +1,41 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package idputil + +import ( + "fmt" + "net/http" + "time" + + "github.com/acronis/go-appkit/httpclient" + "github.com/acronis/go-appkit/log" + + "github.com/acronis/go-authkit/internal/libinfo" +) + +const ( + DefaultHTTPRequestTimeout = 30 * time.Second + DefaultHTTPRequestMaxRetryAttempts = 3 +) + +func MakeDefaultHTTPClient(reqTimeout time.Duration, logger log.FieldLogger) *http.Client { + if reqTimeout == 0 { + reqTimeout = DefaultHTTPRequestTimeout + } + var tr http.RoundTripper = http.DefaultTransport.(*http.Transport).Clone() + tr, _ = httpclient.NewRetryableRoundTripperWithOpts(tr, httpclient.RetryableRoundTripperOpts{ + MaxRetryAttempts: DefaultHTTPRequestMaxRetryAttempts, Logger: logger}) // error is always nil + tr = httpclient.NewUserAgentRoundTripper(tr, libinfo.LibName+"/"+libinfo.GetLibVersion()) + return &http.Client{Timeout: reqTimeout, Transport: tr} +} + +func PrepareLogger(logger log.FieldLogger) log.FieldLogger { + if logger == nil { + return log.NewDisabledLogger() + } + return log.NewPrefixedLogger(logger, fmt.Sprintf("[%s/%s] ", libinfo.LibName, libinfo.GetLibVersion())) +} diff --git a/internal/libinfo/lib_info.go b/internal/libinfo/lib_info.go index b08d088..ea8cd2e 100644 --- a/internal/libinfo/lib_info.go +++ b/internal/libinfo/lib_info.go @@ -5,22 +5,3 @@ Released under MIT license. */ package libinfo - -import ( - "fmt" -) - -const libName = "go-authkit" - -const libPath = "github.com/acronis/" + libName - -func MakeUserAgent(prependedUserAgent string) string { - if prependedUserAgent != "" { - prependedUserAgent += " " - } - return prependedUserAgent + " " + libName + "/" + GetLibVersion() -} - -func GetLogPrefix() string { - return fmt.Sprintf("[%s/%s]", libName, GetLibVersion()) -} diff --git a/internal/libinfo/version.go b/internal/libinfo/version.go index a533b5b..cefbc8d 100644 --- a/internal/libinfo/version.go +++ b/internal/libinfo/version.go @@ -12,6 +12,10 @@ import ( "runtime/debug" ) +const LibName = "go-authkit" + +const libPath = "github.com/acronis/" + LibName + var libVersion string var libVersionOnce sync.Once diff --git a/jwks/caching_client.go b/jwks/caching_client.go index 6954618..a448b57 100644 --- a/jwks/caching_client.go +++ b/jwks/caching_client.go @@ -9,11 +9,9 @@ package jwks import ( "context" "fmt" - "net/http" "sync" "time" - "github.com/acronis/go-appkit/log" "github.com/acronis/go-appkit/lrucache" ) @@ -44,17 +42,17 @@ type issuerCacheEntry struct { } // NewCachingClient returns a new Client that can cache fetched data. -func NewCachingClient(httpClient *http.Client, logger log.FieldLogger) *CachingClient { - return NewCachingClientWithOpts(httpClient, logger, CachingClientOpts{}) +func NewCachingClient() *CachingClient { + return NewCachingClientWithOpts(CachingClientOpts{}) } // NewCachingClientWithOpts returns a new Client that can cache fetched data with options. -func NewCachingClientWithOpts(httpClient *http.Client, logger log.FieldLogger, opts CachingClientOpts) *CachingClient { +func NewCachingClientWithOpts(opts CachingClientOpts) *CachingClient { if opts.CacheUpdateMinInterval == 0 { opts.CacheUpdateMinInterval = DefaultCacheUpdateMinInterval } return &CachingClient{ - rawClient: NewClientWithOpts(httpClient, logger, opts.ClientOpts), + rawClient: NewClientWithOpts(opts.ClientOpts), issuerCache: make(map[string]issuerCacheEntry), cacheUpdateMinInterval: opts.CacheUpdateMinInterval, } diff --git a/jwks/caching_client_test.go b/jwks/caching_client_test.go index 395bbba..88ff9b6 100644 --- a/jwks/caching_client_test.go +++ b/jwks/caching_client_test.go @@ -10,13 +10,11 @@ import ( "context" "crypto/rsa" "errors" - "net/http" "net/http/httptest" "sync" "testing" "time" - "github.com/acronis/go-appkit/log" "github.com/stretchr/testify/require" "github.com/acronis/go-authkit/idptest" @@ -32,8 +30,7 @@ func TestCachingClient_GetRSAPublicKey(t *testing.T) { issuerConfigServer := httptest.NewServer(issuerConfigHandler) defer issuerConfigServer.Close() - cachingClient := jwks.NewCachingClientWithOpts(http.DefaultClient, log.NewDisabledLogger(), - jwks.CachingClientOpts{CacheUpdateMinInterval: time.Second * 10}) + cachingClient := jwks.NewCachingClientWithOpts(jwks.CachingClientOpts{CacheUpdateMinInterval: time.Second * 10}) var wg sync.WaitGroup const callsNum = 10 wg.Add(callsNum) @@ -75,8 +72,7 @@ func TestCachingClient_GetRSAPublicKey(t *testing.T) { const unknownKeyID = "77777777-7777-7777-7777-777777777777" const cacheUpdateMinInterval = time.Second * 1 - cachingClient := jwks.NewCachingClientWithOpts(http.DefaultClient, log.NewDisabledLogger(), - jwks.CachingClientOpts{CacheUpdateMinInterval: cacheUpdateMinInterval}) + cachingClient := jwks.NewCachingClientWithOpts(jwks.CachingClientOpts{CacheUpdateMinInterval: cacheUpdateMinInterval}) doGetPublicKeyByUnknownID := func(callsNum int) { t.Helper() diff --git a/jwks/client.go b/jwks/client.go index 9af6075..69e3e0f 100644 --- a/jwks/client.go +++ b/jwks/client.go @@ -31,6 +31,12 @@ type jwksData struct { // ClientOpts contains options for the JWKS client. type ClientOpts struct { + // HTTPClient is an HTTP client for making requests. + HTTPClient *http.Client + + // Logger is a logger for the client. + Logger log.FieldLogger + // PrometheusLibInstanceLabel is a label for Prometheus metrics. // It allows distinguishing metrics from different instances of the same library. PrometheusLibInstanceLabel string @@ -47,14 +53,18 @@ type Client struct { } // NewClient returns a new Client. -func NewClient(httpClient *http.Client, logger log.FieldLogger) *Client { - return NewClientWithOpts(httpClient, logger, ClientOpts{}) +func NewClient() *Client { + return NewClientWithOpts(ClientOpts{}) } // NewClientWithOpts returns a new Client with options. -func NewClientWithOpts(httpClient *http.Client, logger log.FieldLogger, opts ClientOpts) *Client { +func NewClientWithOpts(opts ClientOpts) *Client { promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "jwks_client") - return &Client{httpClient, logger, promMetrics} + opts.Logger = idputil.PrepareLogger(opts.Logger) + if opts.HTTPClient == nil { + opts.HTTPClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, opts.Logger) + } + return &Client{httpClient: opts.HTTPClient, logger: opts.Logger, promMetrics: promMetrics} } func (c *Client) getRSAPubKeysForIssuer(ctx context.Context, issuerURL string) (map[string]interface{}, error) { diff --git a/jwks/client_test.go b/jwks/client_test.go index b95e1f6..5bcca45 100644 --- a/jwks/client_test.go +++ b/jwks/client_test.go @@ -16,7 +16,6 @@ import ( "strings" "testing" - "github.com/acronis/go-appkit/log" "github.com/stretchr/testify/require" "github.com/acronis/go-authkit/idptest" @@ -30,7 +29,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) defer issuerConfigServer.Close() - client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + client := jwks.NewClient() pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) require.NoError(t, err) require.NotNil(t, pubKey) @@ -43,7 +42,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) issuerConfigServer.Close() // Close the server immediately. - client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + client := jwks.NewClient() pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) require.Error(t, err) var openIDCfgErr *jwks.GetOpenIDConfigurationError @@ -59,7 +58,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { })) defer issuerConfigServer.Close() - client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + client := jwks.NewClient() pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) require.Error(t, err) var openIDCfgErr *jwks.GetOpenIDConfigurationError @@ -76,7 +75,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { })) defer issuerConfigServer.Close() - client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + client := jwks.NewClient() pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) require.Error(t, err) var openIDCfgErr *jwks.GetOpenIDConfigurationError @@ -93,7 +92,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) defer issuerConfigServer.Close() - client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + client := jwks.NewClient() pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) require.Error(t, err) var jwksErr *jwks.GetJWKSError @@ -112,7 +111,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) defer issuerConfigServer.Close() - client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + client := jwks.NewClient() pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, idptest.TestKeyID) require.Error(t, err) var jwksErr *jwks.GetJWKSError @@ -131,7 +130,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { const unknownKeyID = "77777777-7777-7777-7777-777777777777" - client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + client := jwks.NewClient() pubKey, err := client.GetRSAPublicKey(context.Background(), issuerConfigServer.URL, unknownKeyID) require.Error(t, err) var jwkErr *jwks.JWKNotFoundError @@ -147,7 +146,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) defer issuerConfigServer.Close() - client := jwks.NewClient(http.DefaultClient, log.NewDisabledLogger()) + client := jwks.NewClient() ctx, cancelCtxFn := context.WithCancel(context.Background()) cancelCtxFn() // Emulate canceling context. pubKey, err := client.GetRSAPublicKey(ctx, issuerConfigServer.URL, idptest.TestKeyID) diff --git a/jwt/caching_parser_test.go b/jwt/caching_parser_test.go index f0ede58..8b50adb 100644 --- a/jwt/caching_parser_test.go +++ b/jwt/caching_parser_test.go @@ -9,7 +9,6 @@ package jwt_test import ( "context" "crypto/sha256" - "net/http" "net/http/httptest" "testing" "time" @@ -54,7 +53,7 @@ func TestCachingParser_Parse(t *testing.T) { claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} tokenString := idptest.MustMakeTokenStringSignedWithTestKey(claims) - parser, err := jwt.NewCachingParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser, err := jwt.NewCachingParser(jwks.NewCachingClient(), logger) require.NoError(t, err) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) @@ -89,7 +88,7 @@ func TestCachingParser_CheckExpiration(t *testing.T) { claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(jwtTTL))}} tokenString := idptest.MustMakeTokenStringSignedWithTestKey(claims) - parser, err := jwt.NewCachingParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser, err := jwt.NewCachingParser(jwks.NewCachingClient(), logger) require.NoError(t, err) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 33c12a8..a539890 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -8,7 +8,6 @@ package jwt_test import ( "context" - "net/http" "net/http/httptest" "testing" "time" @@ -43,7 +42,7 @@ func TestJWTParser_Parse(t *testing.T) { TOTPTime: time.Now().Unix(), SubType: "task_manager", } - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.NoError(t, err) @@ -59,7 +58,7 @@ func TestJWTParser_Parse(t *testing.T) { ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), }, } - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.NoError(t, err) @@ -80,7 +79,7 @@ func TestJWTParser_Parse(t *testing.T) { "http://127.*", } for _, issURL := range issURLs { - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) require.NoError(t, parser.AddTrustedIssuerURL(issURL)) parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.NoError(t, err) @@ -98,7 +97,7 @@ func TestJWTParser_Parse(t *testing.T) { }, Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, } - parser := jwt.NewParserWithOpts(jwks.NewCachingClient(http.DefaultClient, logger), logger, jwt.ParserOpts{ + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), logger, jwt.ParserOpts{ ExpectedAudience: []string{"*.cloud.com"}, }) require.NoError(t, parser.AddTrustedIssuerURL(issuerConfigServer.URL)) @@ -109,7 +108,7 @@ func TestJWTParser_Parse(t *testing.T) { }) t.Run("malformed jwt", func(t *testing.T) { - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) _, err := parser.Parse(context.Background(), "invalid-jwt") require.ErrorIs(t, err, jwtgo.ErrTokenMalformed) require.ErrorContains(t, err, "token contains an invalid number of segments") @@ -127,7 +126,7 @@ func TestJWTParser_Parse(t *testing.T) { tokenString, err := token.SignedString(jwtgo.UnsafeAllowNoneSignatureType) require.NoError(t, err) - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) _, err = parser.Parse(context.Background(), tokenString) require.ErrorIs(t, err, jwtgo.NoneSignatureTypeDisallowedError) }) @@ -136,7 +135,7 @@ func TestJWTParser_Parse(t *testing.T) { claims := &jwt.Claims{ RegisteredClaims: jwtgo.RegisteredClaims{Audience: []string{"https://cloud.acronis.com"}}, } - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) var issMissingErr *jwt.IssuerMissingError @@ -147,7 +146,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("jwt has untrusted issuer", func(t *testing.T) { const issuer = "untrusted-issuer" claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}} - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) var issUntrustedErr *jwt.IssuerUntrustedError @@ -158,7 +157,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("jwt has untrusted issuer url", func(t *testing.T) { const issuer = "https://3rd-party-idp.com" claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}} - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) require.NoError(t, parser.AddTrustedIssuerURL("https://*.acronis.com")) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) @@ -177,7 +176,7 @@ func TestJWTParser_Parse(t *testing.T) { }, Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, } - parser := jwt.NewParserWithOpts(jwks.NewCachingClient(http.DefaultClient, logger), logger, jwt.ParserOpts{ + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), logger, jwt.ParserOpts{ TrustedIssuerNotFoundFallback: func(ctx context.Context, p *jwt.Parser, iss string) (issURL string, issFound bool) { callbackCallCount++ addErr := p.AddTrustedIssuerURL(iss) @@ -200,7 +199,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("jwt exp is missing", func(t *testing.T) { claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss}} - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) @@ -210,7 +209,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("jwt expired", func(t *testing.T) { expiresAt := time.Now().Add(-time.Second) claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(expiresAt)}} - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) @@ -224,7 +223,7 @@ func TestJWTParser_Parse(t *testing.T) { ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), NotBefore: jwtgo.NewNumericDate(notBefore), }} - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) @@ -236,7 +235,7 @@ func TestJWTParser_Parse(t *testing.T) { Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), }} - parser := jwt.NewParserWithOpts(jwks.NewCachingClient(http.DefaultClient, logger), logger, jwt.ParserOpts{ + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), logger, jwt.ParserOpts{ RequireAudience: true, }) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) @@ -255,7 +254,7 @@ func TestJWTParser_Parse(t *testing.T) { Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), }} - parser := jwt.NewParserWithOpts(jwks.NewCachingClient(http.DefaultClient, logger), logger, jwt.ParserOpts{ + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), logger, jwt.ParserOpts{ ExpectedAudience: []string{"expected-audience"}, }) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) @@ -280,7 +279,7 @@ func TestJWTParser_Parse(t *testing.T) { claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} tokenString, err := idptest.MakeTokenString(claims, "737c5114f09b5ed05276bd4b520245982f7fb29f", idptest.GetTestRSAPrivateKey()) require.NoError(t, err) - jwksClient := jwks.NewCachingClientWithOpts(http.DefaultClient, logger, jwks.CachingClientOpts{CacheUpdateMinInterval: cacheUpdateMinInterval}) + jwksClient := jwks.NewCachingClientWithOpts(jwks.CachingClientOpts{CacheUpdateMinInterval: cacheUpdateMinInterval}) parser := jwt.NewParser(jwksClient, logger) parser.AddTrustedIssuer(testIss, openIDCfgServer2.URL) @@ -337,7 +336,7 @@ func TestParser_getURLForIssuer(t *testing.T) { tt := tests[i] t.Run(tt.Name, func(t *testing.T) { logger := log.NewDisabledLogger() - parser := jwt.NewParser(jwks.NewCachingClient(http.DefaultClient, logger), logger) + parser := jwt.NewParser(jwks.NewCachingClient(), logger) require.NoError(t, parser.AddTrustedIssuerURL(tt.IssURLPattern)) for _, issURL := range tt.TrustedIssURLs { u, ok := parser.GetURLForIssuer(issURL)