Skip to content

Commit

Permalink
fn: Remove ctx from GoroutineManager constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
ellemouton committed Dec 10, 2024
1 parent 5659c01 commit 66e8803
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 159 deletions.
145 changes: 112 additions & 33 deletions fn/goroutine_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,122 @@ package fn
import (
"context"
"sync"
"sync/atomic"
)

// GoroutineManager is used to launch goroutines until context expires or the
// manager is stopped. The Stop method blocks until all started goroutines stop.
type GoroutineManager struct {
wg sync.WaitGroup
mu sync.Mutex
ctx context.Context
cancel func()
// id is used to generate unique ids for each goroutine.
id atomic.Uint32

// cancelFns is a map of cancel functions that can be used to cancel the
// context of a goroutine. The mutex must be held when accessing this
// map. The key is the id of the goroutine.
cancelFns map[uint32]context.CancelFunc

mu sync.Mutex

stopped sync.Once
quit chan struct{}
wg sync.WaitGroup
}

// NewGoroutineManager constructs and returns a new instance of
// GoroutineManager.
func NewGoroutineManager(ctx context.Context) *GoroutineManager {
ctx, cancel := context.WithCancel(ctx)

func NewGoroutineManager() *GoroutineManager {
return &GoroutineManager{
ctx: ctx,
cancel: cancel,
cancelFns: make(map[uint32]context.CancelFunc),
quit: make(chan struct{}),
}
}

// addCancelFn adds a context cancel function to the manager and returns an id
// that can can be used to cancel the context later on when the goroutine is
// done.
func (g *GoroutineManager) addCancelFn(cancel context.CancelFunc) uint32 {
g.mu.Lock()
defer g.mu.Unlock()

id := g.id.Add(1)
g.cancelFns[id] = cancel

return id
}

// cancel cancels the context associated with the passed id.
func (g *GoroutineManager) cancel(id uint32) {
g.mu.Lock()
defer g.mu.Unlock()

g.cancelUnsafe(id)
}

// cancelUnsafe cancels the context associated with the passed id without
// acquiring the mutex.
func (g *GoroutineManager) cancelUnsafe(id uint32) {
fn, ok := g.cancelFns[id]
if !ok {
return
}

fn()

delete(g.cancelFns, id)
}

// Go tries to start a new goroutine and returns a boolean indicating its
// success. It fails iff the goroutine manager is stopping or its context passed
// to NewGoroutineManager has expired.
func (g *GoroutineManager) Go(f func(ctx context.Context)) bool {
// Calling wg.Add(1) and wg.Wait() when wg's counter is 0 is a race
// condition, since it is not clear should Wait() block or not. This
// success. It returns true if the goroutine was successfully created and false
// otherwise. A goroutine will fail to be created iff the goroutine manager is
// stopping or the passed context has already expired. The passed call-back
// function must exit if the passed context expires.
func (g *GoroutineManager) Go(ctx context.Context,
f func(ctx context.Context)) bool {

// Derive a cancellable context from the passed context and store its
// cancel function in the manager. The context will be cancelled when
// either the parent context is cancelled or the quit channel is closed
// which will call the stored cancel function.
ctx, cancel := context.WithCancel(ctx)
id := g.addCancelFn(cancel)

// Calling wg.Add(1) and wg.Wait() when the wg's counter is 0 is a race
// condition, since it is not clear if Wait() should block or not. This
// kind of race condition is detected by Go runtime and results in a
// crash if running with `-race`. To prevent this, whole Go method is
// protected with a mutex. The call to wg.Wait() inside Stop() can still
// run in parallel with Go, but in that case g.ctx is in expired state,
// because cancel() was called in Stop, so Go returns before wg.Add(1)
// call.
// crash if running with `-race`. To prevent this, we protect the calls
// to wg.Add(1) and wg.Wait() with a mutex. If we block here because
// Stop is running first, then Stop will cancel the quit channel which
// will cause the context to be cancelled, and we will exit before
// calling wg.Add(1). If we grab the mutex here before Stop does, then
// Stop will block until after we call wg.Add(1).
g.mu.Lock()
defer g.mu.Unlock()

if g.ctx.Err() != nil {
// Before continuing to start the goroutine, we need to check if the
// context has already expired. This could be the case if the parent
// context has already expired or if Stop has been called.
if ctx.Err() != nil {
g.cancelUnsafe(id)

return false
}

// Ensure that the goroutine is not started if the manager has stopped.
select {
case <-g.quit:
g.cancelUnsafe(id)
return false
default:
}

g.wg.Add(1)
go func() {
defer g.wg.Done()
f(g.ctx)
defer func() {
g.cancel(id)
g.wg.Done()
}()

f(ctx)
}()

return true
Expand All @@ -56,20 +127,28 @@ func (g *GoroutineManager) Go(f func(ctx context.Context)) bool {
// Stop prevents new goroutines from being added and waits for all running
// goroutines to finish.
func (g *GoroutineManager) Stop() {
g.mu.Lock()
g.cancel()
g.mu.Unlock()

// Wait for all goroutines to finish. Note that this wg.Wait() call is
// safe, since it can't run in parallel with wg.Add(1) call in Go, since
// we just cancelled the context and even if Go call starts running here
// after acquiring the mutex, it would see that the context has expired
// and return false instead of calling wg.Add(1).
g.wg.Wait()
g.stopped.Do(func() {
// Closing the quit channel will prevent any new goroutines from
// starting.
g.mu.Lock()
close(g.quit)
for _, cancel := range g.cancelFns {
cancel()
}
g.mu.Unlock()

// Wait for all goroutines to finish. Note that this wg.Wait()
// call is safe, since it can't run in parallel with wg.Add(1)
// call in Go, since we just cancelled the context and even if
// Go call starts running here after acquiring the mutex, it
// would see that the context has expired and return false
// instead of calling wg.Add(1).
g.wg.Wait()
})
}

// Done returns a channel which is closed when either the context passed to
// NewGoroutineManager expires or when Stop is called.
func (g *GoroutineManager) Done() <-chan struct{} {
return g.ctx.Done()
return g.quit
}
Loading

0 comments on commit 66e8803

Please sign in to comment.