Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collect Prometheus metrics with token introspection result status #25

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion idptoken/grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func NewGRPCClientWithOpts(
client: pb.NewIDPTokenServiceClient(conn),
clientConn: conn,
reqTimeout: opts.RequestTimeout,
promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "grpc_client"),
promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, metrics.SourceGRPCClient),
}, nil
}

Expand Down
4 changes: 1 addition & 3 deletions idptoken/introspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ import (

const minAccessTokenProviderInvalidationInterval = time.Minute

const tokenIntrospectorPromSource = "token_introspector"

const (
// DefaultIntrospectionClaimsCacheMaxEntries is a default maximum number of entries in the claims cache.
// Claims cache is used for storing introspected active tokens.
Expand Down Expand Up @@ -250,7 +248,7 @@ func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opt
}
scopeFilterFormURLEncoded := values.Encode()

promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, tokenIntrospectorPromSource)
promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, metrics.SourceTokenIntrospector)

claimsCache := makeIntrospectionClaimsCache(opts.ClaimsCache, DefaultIntrospectionClaimsCacheMaxEntries, promMetrics)
if opts.ClaimsCache.TTL == 0 {
Expand Down
2 changes: 1 addition & 1 deletion idptoken/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func NewMultiSourceProviderWithOpts(sources []Source, opts ProviderOpts) *MultiS
minRefreshPeriod: opts.MinRefreshPeriod,
logger: idputil.PrepareLogger(opts.Logger),
tokenIssuers: make(map[string]*oauth2Issuer),
promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "token_provider"),
promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, metrics.SourceTokenProvider),
customHeaders: opts.CustomHeaders,
cache: opts.CustomCacheInstance,
httpClient: opts.HTTPClient,
Expand Down
4 changes: 2 additions & 2 deletions idptoken/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func TestProviderWithCache(t *testing.T) {
metrics.HTTPClientRequestLabelStatusCode: "500",
metrics.HTTPClientRequestLabelError: "unexpected_status_code",
}
promMetrics := metrics.GetPrometheusMetrics("", "token_provider")
promMetrics := metrics.GetPrometheusMetrics("", metrics.SourceTokenProvider)
hist := promMetrics.HTTPClientRequestDuration.With(labels).(prometheus.Histogram)
testutil.AssertSamplesCountInHistogram(t, hist, 1)
})
Expand Down Expand Up @@ -287,7 +287,7 @@ func TestProviderWithCache(t *testing.T) {
metrics.HTTPClientRequestLabelStatusCode: "200",
metrics.HTTPClientRequestLabelError: "",
}
promMetrics := metrics.GetPrometheusMetrics("", "token_provider")
promMetrics := metrics.GetPrometheusMetrics("", metrics.SourceTokenProvider)
hist := promMetrics.HTTPClientRequestDuration.With(labels).(prometheus.Histogram)
testutil.AssertSamplesCountInHistogram(t, hist, 1)
})
Expand Down
41 changes: 39 additions & 2 deletions internal/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,31 @@ const (

GRPCClientRequestLabelMethod = "grpc_method"
GRPCClientRequestLabelCode = "grpc_code"

TokenIntrospectionLabelStatus = "status"
)

const (
HTTPRequestErrorDo = "do_request_error"
HTTPRequestErrorDecodeBody = "decode_body_error"
HTTPRequestErrorUnexpectedStatusCode = "unexpected_status_code"

TokenIntrospectionStatusActive = "active"
TokenIntrospectionStatusNotActive = "not_active"
TokenIntrospectionStatusNotNeeded = "not_needed"
TokenIntrospectionStatusNotIntrospectable = "not_introspectable"
TokenIntrospectionStatusError = "error"
)

type Source string

const (
SourceJWKSClient Source = "jwks_client"
SourceJWTParser Source = "jwt_parser"
SourceGRPCClient Source = "grpc_client"
SourceTokenIntrospector Source = "token_introspector"
SourceTokenProvider Source = "token_provider"
SourceHTTPMiddleware Source = "http_middleware"
)

var requestDurationBuckets = []float64{0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}
Expand All @@ -58,12 +77,13 @@ var (
type PrometheusMetrics struct {
HTTPClientRequestDuration *prometheus.HistogramVec
GRPCClientRequestDuration *prometheus.HistogramVec
TokenIntrospectionsTotal *prometheus.CounterVec
TokenClaimsCache *lrucache.PrometheusMetrics
TokenNegativeCache *lrucache.PrometheusMetrics
EndpointDiscoveryCache *lrucache.PrometheusMetrics
}

func GetPrometheusMetrics(instance string, source string) *PrometheusMetrics {
func GetPrometheusMetrics(instance string, source Source) *PrometheusMetrics {
prometheusMetricsOnce.Do(func() {
prometheusMetrics = newPrometheusMetrics()
prometheusMetrics.MustRegister()
Expand All @@ -73,7 +93,7 @@ func GetPrometheusMetrics(instance string, source string) *PrometheusMetrics {
}
return prometheusMetrics.MustCurryWith(map[string]string{
PrometheusLibInstanceLabel: instance,
PrometheusLibSourceLabel: source,
PrometheusLibSourceLabel: string(source),
})
}

Expand All @@ -95,6 +115,7 @@ func newPrometheusMetrics() *PrometheusMetrics {
makeLabelNames(HTTPClientRequestLabelMethod, HTTPClientRequestLabelURL,
HTTPClientRequestLabelStatusCode, HTTPClientRequestLabelError),
)

grpcClientReqDuration := prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: PrometheusNamespace,
Expand All @@ -106,6 +127,16 @@ func newPrometheusMetrics() *PrometheusMetrics {
makeLabelNames(GRPCClientRequestLabelMethod, GRPCClientRequestLabelCode),
)

tokenIntrospectionsTotal := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: PrometheusNamespace,
Name: "token_introspections_total",
Help: "Total number of tokens' introspections",
ConstLabels: PrometheusLabels(),
},
makeLabelNames(TokenIntrospectionLabelStatus),
)

tokenClaimsCache := lrucache.NewPrometheusMetricsWithOpts(lrucache.PrometheusMetricsOpts{
Namespace: PrometheusNamespace + "_token_claims",
ConstLabels: PrometheusLabels(),
Expand All @@ -127,6 +158,7 @@ func newPrometheusMetrics() *PrometheusMetrics {
return &PrometheusMetrics{
HTTPClientRequestDuration: httpClientReqDuration,
GRPCClientRequestDuration: grpcClientReqDuration,
TokenIntrospectionsTotal: tokenIntrospectionsTotal,
TokenClaimsCache: tokenClaimsCache,
TokenNegativeCache: tokenNegativeCache,
EndpointDiscoveryCache: endpointDiscoveryCache,
Expand All @@ -138,6 +170,7 @@ func (pm *PrometheusMetrics) MustCurryWith(labels prometheus.Labels) *Prometheus
return &PrometheusMetrics{
HTTPClientRequestDuration: pm.HTTPClientRequestDuration.MustCurryWith(labels).(*prometheus.HistogramVec),
GRPCClientRequestDuration: pm.GRPCClientRequestDuration.MustCurryWith(labels).(*prometheus.HistogramVec),
TokenIntrospectionsTotal: pm.TokenIntrospectionsTotal.MustCurryWith(labels),
TokenClaimsCache: pm.TokenClaimsCache.MustCurryWith(labels),
TokenNegativeCache: pm.TokenNegativeCache.MustCurryWith(labels),
EndpointDiscoveryCache: pm.EndpointDiscoveryCache.MustCurryWith(labels),
Expand Down Expand Up @@ -183,3 +216,7 @@ func (pm *PrometheusMetrics) ObserveGRPCClientRequest(
GRPCClientRequestLabelCode: code.String(),
}).Observe(elapsed.Seconds())
}

func (pm *PrometheusMetrics) IncTokenIntrospectionsTotal(status string) {
pm.TokenIntrospectionsTotal.With(prometheus.Labels{TokenIntrospectionLabelStatus: status}).Inc()
}
2 changes: 1 addition & 1 deletion jwks/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func NewClient() *Client {

// NewClientWithOpts returns a new Client with options.
func NewClientWithOpts(opts ClientOpts) *Client {
promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "jwks_client")
promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, metrics.SourceJWKSClient)
if opts.HTTPClient == nil {
opts.HTTPClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, opts.LoggerProvider)
}
Expand Down
2 changes: 1 addition & 1 deletion jwt/caching_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func NewCachingParser(keysProvider KeysProvider) (*CachingParser, error) {
func NewCachingParserWithOpts(
keysProvider KeysProvider, opts CachingParserOpts,
) (*CachingParser, error) {
promMetrics := metrics.GetPrometheusMetrics(opts.CachePrometheusInstanceLabel, "jwt_parser")
promMetrics := metrics.GetPrometheusMetrics(opts.CachePrometheusInstanceLabel, metrics.SourceJWTParser)
if opts.CacheMaxEntries == 0 {
opts.CacheMaxEntries = DefaultClaimsCacheMaxEntries
}
Expand Down
22 changes: 19 additions & 3 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"github.com/acronis/go-authkit/idptoken"
"github.com/acronis/go-authkit/internal/idputil"
"github.com/acronis/go-authkit/internal/metrics"
"github.com/acronis/go-authkit/jwt"
)

Expand Down Expand Up @@ -70,12 +71,14 @@ type jwtAuthHandler struct {
verifyAccess func(r *http.Request, claims jwt.Claims) bool
tokenIntrospector TokenIntrospector
loggerProvider func(ctx context.Context) log.FieldLogger
promMetrics *metrics.PrometheusMetrics
}

type jwtAuthMiddlewareOpts struct {
verifyAccess func(r *http.Request, claims jwt.Claims) bool
tokenIntrospector TokenIntrospector
loggerProvider func(ctx context.Context) log.FieldLogger
verifyAccess func(r *http.Request, claims jwt.Claims) bool
tokenIntrospector TokenIntrospector
loggerProvider func(ctx context.Context) log.FieldLogger
prometheusLibInstanceLabel string
}

// JWTAuthMiddlewareOption is an option for JWTAuthMiddleware.
Expand All @@ -102,6 +105,13 @@ func WithJWTAuthMiddlewareLoggerProvider(loggerProvider func(ctx context.Context
}
}

// WithJWTAuthMiddlewarePrometheusLibInstanceLabel is an option to set a label for Prometheus metrics that are used by JWTAuthMiddleware.
func WithJWTAuthMiddlewarePrometheusLibInstanceLabel(label string) JWTAuthMiddlewareOption {
return func(options *jwtAuthMiddlewareOpts) {
options.prometheusLibInstanceLabel = label
}
}

// JWTAuthMiddleware is a middleware that does authentication
// by Access Token from the "Authorization" HTTP header of incoming request.
// errorDomain is used for error responses. It is usually the name of the service that uses the middleware,
Expand All @@ -123,6 +133,7 @@ func JWTAuthMiddleware(errorDomain string, jwtParser JWTParser, opts ...JWTAuthM
verifyAccess: options.verifyAccess,
tokenIntrospector: options.tokenIntrospector,
loggerProvider: options.loggerProvider,
promMetrics: metrics.GetPrometheusMetrics(options.prometheusLibInstanceLabel, metrics.SourceHTTPMiddleware),
}
}
}
Expand All @@ -146,21 +157,25 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
logFunc("token's introspection is not needed")
})
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotNeeded)
case errors.Is(err, idptoken.ErrTokenNotIntrospectable):
// Token is not introspectable by some reason.
// In this case, we will parse it as JWT and use it for authZ.
h.logger(reqCtx).Warn("token is not introspectable, it will be used for authentication and authorization as is",
log.Error(err))
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotIntrospectable)
default:
logger := h.logger(reqCtx)
logger.Error("token's introspection failed", log.Error(err))
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusError)
apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger)
return
}
} else {
if !introspectionResult.IsActive() {
h.logger(reqCtx).Warn("token was successfully introspected, but it is not active")
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotActive)
apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx))
return
Expand All @@ -169,6 +184,7 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
logFunc("token was successfully introspected")
})
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusActive)
}
}

Expand Down
32 changes: 31 additions & 1 deletion middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/acronis/go-authkit/idptoken"
"github.com/acronis/go-authkit/internal/metrics"
"github.com/acronis/go-authkit/jwt"
)

Expand Down Expand Up @@ -122,12 +123,18 @@ func TestJWTAuthMiddleware(t *testing.T) {
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusError), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)

testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed)
require.Equal(t, 1, introspector.introspectCalled)
require.Equal(t, 0, parser.parseCalled)
require.Equal(t, 0, next.called)

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusError), 1)
})

t.Run("introspection is not needed", func(t *testing.T) {
Expand All @@ -139,6 +146,9 @@ func TestJWTAuthMiddleware(t *testing.T) {
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotNeeded), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)

require.Equal(t, http.StatusOK, resp.Code)
Expand All @@ -148,6 +158,9 @@ func TestJWTAuthMiddleware(t *testing.T) {
nextIssuer, err := next.jwtClaims.GetIssuer()
require.NoError(t, err)
require.Equal(t, issuer, nextIssuer)

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotNeeded), 1)
})

t.Run("ok, token is not introspectable", func(t *testing.T) {
Expand All @@ -159,6 +172,9 @@ func TestJWTAuthMiddleware(t *testing.T) {
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotIntrospectable), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)

require.Equal(t, http.StatusOK, resp.Code)
Expand All @@ -169,23 +185,31 @@ func TestJWTAuthMiddleware(t *testing.T) {
nextIssuer, err := next.jwtClaims.GetIssuer()
require.NoError(t, err)
require.Equal(t, issuer, nextIssuer)

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotIntrospectable), 1)
})

t.Run("authentication failed, token is introspected but inactive", func(t *testing.T) {
const issuer = "my-idp.com"
parser := &mockJWTParser{}
introspector := &mockTokenIntrospector{resultToReturn: &idptoken.DefaultIntrospectionResult{Active: false}}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotActive), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)

testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed)
require.Equal(t, 1, introspector.introspectCalled)
require.Equal(t, 0, parser.parseCalled)
require.Equal(t, 0, next.called)

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotActive), 1)
})

t.Run("ok, token is introspected and active", func(t *testing.T) {
Expand All @@ -198,6 +222,9 @@ func TestJWTAuthMiddleware(t *testing.T) {
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusActive), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)

require.Equal(t, http.StatusOK, resp.Code)
Expand All @@ -208,6 +235,9 @@ func TestJWTAuthMiddleware(t *testing.T) {
nextIssuer, err := next.jwtClaims.GetIssuer()
require.NoError(t, err)
require.Equal(t, issuer, nextIssuer)

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusActive), 1)
})
}

Expand Down
Loading