diff --git a/context.go b/context.go index ad23101..dc10a12 100644 --- a/context.go +++ b/context.go @@ -15,24 +15,31 @@ func WithContext(parent context.Context, fields ...Field) context.Context { return context.WithValue(parent, ContextKeyLogFields, makeFieldStack().push(fields)) } -// PushContextFields pushes the given fields onto the logging fields stack. PushContextFields panics if the -// context has not been initialized via WithContext. +// PushContextFields pushes the given fields onto the logging fields stack. func PushContextFields(ctx context.Context, fields ...Field) { stack := getStack(ctx) + if stack == nil { + return + } stack.push(fields) } -// PopContextFields pops the last entry off of the logging fields stack. PopContextFields panics if the -// context has not been initialized via WithContext. +// PopContextFields pops the last entry off of the logging fields stack. func PopContextFields(ctx context.Context) { stack := getStack(ctx) + if stack == nil { + return + } stack.pop() } -// GetContextFields retrieves the logging `Fields` from context. GetContextFields panics if the -// context has not been initialized via WithContext. +// GetContextFields retrieves the logging `Fields` from context. GetContextFields returns an empty Fields map +// if the context has not been initialized by calling WithContext. func GetContextFields(ctx context.Context, additionalFields ...Field) Fields { stack := getStack(ctx) + if stack == nil { + return make(Fields) + } fields := stack.allFields() for _, f := range additionalFields { fields[f.Name] = f.Value @@ -43,11 +50,12 @@ func GetContextFields(ctx context.Context, additionalFields ...Field) Fields { func getStack(ctx context.Context) *fieldStack { stackObj := ctx.Value(ContextKeyLogFields) if stackObj == nil { - panic("logging fields have not been added to context yet; call WithContext") + Warn("context logging fields not initialized; call log.WithContext") + return nil } stack, ok := stackObj.(*fieldStack) if !ok { - panic("logging fields are not of the correct type") + Warn("context logging fields has incorrect type") } return stack } diff --git a/context_test.go b/context_test.go index d4fea3e..93700b8 100644 --- a/context_test.go +++ b/context_test.go @@ -31,6 +31,25 @@ var _ = Describe("Context", func() { fields = GetContextFields(ctx) g.Expect(len(fields)).To(g.Equal(1)) g.Expect(fields["foo"]).To(g.Equal(5)) + + // pop should be safe to call even beyond actual stack height + PopContextFields(ctx) + PopContextFields(ctx) + PopContextFields(ctx) + }) + + It("should safely ignore when context is not initialized", func() { + ctx := context.Background() + + fields := GetContextFields(ctx) + g.Expect(len(fields)).To(g.Equal(0)) + + PushContextFields(ctx, MakeField("foo", 0)) + PopContextFields(ctx) + + ctx = context.WithValue(ctx, ContextKeyLogFields, "wrong") + PopContextFields(ctx) + PopContextFields(ctx) }) }) })