diff --git a/protofsm/state_machine.go b/protofsm/state_machine.go index 7711bb7880..c3ae7928b2 100644 --- a/protofsm/state_machine.go +++ b/protofsm/state_machine.go @@ -228,17 +228,18 @@ func (s *StateMachine[Event, Env]) Stop() { // SendEvent sends a new event to the state machine. // // TODO(roasbeef): bool if processed? -func (s *StateMachine[Event, Env]) SendEvent(event Event) { - s.sendEvent(event) +func (s *StateMachine[Event, Env]) SendEvent(ctx context.Context, event Event) { + s.sendEvent(ctx, event) } // sendEvent sends a new event to the state machine. -func (s *StateMachine[Event, Env]) sendEvent(event Event) { +func (s *StateMachine[Event, Env]) sendEvent(ctx context.Context, event Event) { log.Debugf("FSM(%v): sending event: %v", s.cfg.Env.Name(), lnutils.SpewLogClosure(event)) select { case s.events <- event: + case <-ctx.Done(): case <-s.quit: return } @@ -262,15 +263,19 @@ func (s *StateMachine[Event, Env]) Name() string { // message can be mapped using the default message mapper, then true is // returned indicating that the message was processed. Otherwise, false is // returned. -func (s *StateMachine[Event, Env]) SendMessage(msg lnwire.Message) bool { - return s.sendMessage(msg) +func (s *StateMachine[Event, Env]) SendMessage(ctx context.Context, + msg lnwire.Message) bool { + + return s.sendMessage(ctx, msg) } // sendMessage attempts to send a wire message to the state machine. If the // message can be mapped using the default message mapper, then true is // returned indicating that the message was processed. Otherwise, false is // returned. -func (s *StateMachine[Event, Env]) sendMessage(msg lnwire.Message) bool { +func (s *StateMachine[Event, Env]) sendMessage(ctx context.Context, + msg lnwire.Message) bool { + // If we have no message mapper, then return false as we can't process // this message. if !s.cfg.MsgMapper.IsSome() { @@ -289,7 +294,7 @@ func (s *StateMachine[Event, Env]) sendMessage(msg lnwire.Message) bool { event := mapper.MapMsg(msg) event.WhenSome(func(event Event) { - s.sendEvent(event) + s.sendEvent(ctx, event) processed = true }) @@ -374,7 +379,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(ctx context.Context, s.cfg.Env.Name(), lnutils.SpewLogClosure(event)) - s.sendEvent(event) + s.sendEvent(ctx, event) }, ) @@ -483,7 +488,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(ctx context.Context, postSpend := daemonEvent.PostSpendEvent postSpend.WhenSome(func(f SpendMapper[Event]) { //nolint:ll customEvent := f(spend) - s.sendEvent(customEvent) + s.sendEvent(ctx, customEvent) }) return @@ -527,7 +532,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(ctx context.Context, // dispatchAfterRecv w/ above postConf := daemonEvent.PostConfEvent postConf.WhenSome(func(e Event) { - s.sendEvent(e) + s.sendEvent(ctx, e) }) return diff --git a/protofsm/state_machine_test.go b/protofsm/state_machine_test.go index 6d8bc7e1d7..9b7b8b1211 100644 --- a/protofsm/state_machine_test.go +++ b/protofsm/state_machine_test.go @@ -307,7 +307,7 @@ func TestStateMachineInternalEvents(t *testing.T) { // For this transition, we'll send in the emitInternal event, which'll // send us back to the starting event, but emit an internal event. - stateMachine.SendEvent(&emitInternal{}) + stateMachine.SendEvent(ctx, &emitInternal{}) // We'll now also assert the path we took to get here to ensure the // internal events were processed. @@ -367,7 +367,7 @@ func TestStateMachineDaemonEvents(t *testing.T) { // We'll start off by sending in the daemon event, which'll trigger the // state machine to execute the series of daemon events. - stateMachine.SendEvent(&daemonEvents{}) + stateMachine.SendEvent(ctx, &daemonEvents{}) // We should transition back to the starting state now, after we // started from the very same state. @@ -454,7 +454,8 @@ func TestStateMachineMsgMapper(t *testing.T) { // Next, we'll attempt to send the wire message into the state machine. // We should transition to the final state. - require.True(t, stateMachine.SendMessage(wireError)) + require.True(t, stateMachine.SendMessage(ctx, + wireError)) // We should transition to the final state. expectedStates := []State[dummyEvents, *dummyEnv]{