diff --git a/build/version.go b/build/version.go index ef0627529b..0fcd5764f3 100644 --- a/build/version.go +++ b/build/version.go @@ -47,7 +47,7 @@ const ( // AppPreRelease MUST only contain characters from semanticAlphabet per // the semantic versioning spec. - AppPreRelease = "beta.rc1" + AppPreRelease = "beta.rc2" ) func init() { diff --git a/config_builder.go b/config_builder.go index 1c3a842ef1..7336a42b3b 100644 --- a/config_builder.go +++ b/config_builder.go @@ -35,6 +35,7 @@ import ( "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" + "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -46,7 +47,6 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/rpcwallet" "github.com/lightningnetwork/lnd/macaroons" "github.com/lightningnetwork/lnd/msgmux" - "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/rpcperms" "github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/sqldb" @@ -165,7 +165,7 @@ type AuxComponents struct { // TrafficShaper is an optional traffic shaper that can be used to // control the outgoing channel of a payment. - TrafficShaper fn.Option[routing.TlvTrafficShaper] + TrafficShaper fn.Option[htlcswitch.AuxTrafficShaper] // MsgRouter is an optional message router that if set will be used in // place of a new blank default message router. diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 46af3e5aeb..3a525507a8 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -482,6 +482,20 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error { return err } + c.wg.Add(1) + go c.channelAttendant(bestHeight, state.commitSet) + + return nil +} + +// progressStateMachineAfterRestart attempts to progress the state machine +// after a restart. This makes sure that if the state transition failed, we +// will try to progress the state machine again. Moreover it will relaunch +// resolvers if the channel is still in the pending close state and has not +// been fully resolved yet. +func (c *ChannelArbitrator) progressStateMachineAfterRestart(bestHeight int32, + commitSet *CommitSet) error { + // If the channel has been marked pending close in the database, and we // haven't transitioned the state machine to StateContractClosed (or a // succeeding state), then a state transition most likely failed. We'll @@ -527,7 +541,7 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error { // on-chain state, and our set of active contracts. startingState := c.state nextState, _, err := c.advanceState( - triggerHeight, trigger, state.commitSet, + triggerHeight, trigger, commitSet, ) if err != nil { switch err { @@ -564,14 +578,12 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error { // receive a chain event from the chain watcher that the // commitment has been confirmed on chain, and before we // advance our state step, we call InsertConfirmedCommitSet. - err := c.relaunchResolvers(state.commitSet, triggerHeight) + err := c.relaunchResolvers(commitSet, triggerHeight) if err != nil { return err } } - c.wg.Add(1) - go c.channelAttendant(bestHeight) return nil } @@ -2775,13 +2787,28 @@ func (c *ChannelArbitrator) updateActiveHTLCs() { // Nursery for incubation, and ultimate sweeping. // // NOTE: This MUST be run as a goroutine. -func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { +// +//nolint:funlen +func (c *ChannelArbitrator) channelAttendant(bestHeight int32, + commitSet *CommitSet) { // TODO(roasbeef): tell top chain arb we're done defer func() { c.wg.Done() }() + err := c.progressStateMachineAfterRestart(bestHeight, commitSet) + if err != nil { + // In case of an error, we return early but we do not shutdown + // LND, because there might be other channels that still can be + // resolved and we don't want to interfere with that. + // We continue to run the channel attendant in case the channel + // closes via other means for example the remote pary force + // closes the channel. So we log the error and continue. + log.Errorf("Unable to progress state machine after "+ + "restart: %v", err) + } + for { select { diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 7f8c4b087f..bc825959a3 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -21,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/mock" + "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -1043,10 +1044,19 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { // Post restart, it should be the case that our resolver was properly // supplemented, and we only have a single resolver in the final set. - if len(chanArb.activeResolvers) != 1 { - t.Fatalf("expected single resolver, instead got: %v", - len(chanArb.activeResolvers)) - } + // The resolvers are added concurrently so we need to wait here. + err = wait.NoError(func() error { + chanArb.activeResolversLock.Lock() + defer chanArb.activeResolversLock.Unlock() + + if len(chanArb.activeResolvers) != 1 { + return fmt.Errorf("expected single resolver, instead "+ + "got: %v", len(chanArb.activeResolvers)) + } + + return nil + }, defaultTimeout) + require.NoError(t, err) // We'll now examine the in-memory state of the active resolvers to // ensure t hey were populated properly. @@ -2884,9 +2894,12 @@ func TestChannelArbitratorStartForceCloseFail(t *testing.T) { { name: "Commitment is rejected with an " + "unmatched error", - broadcastErr: fmt.Errorf("Reject Commitment Tx"), - expectedState: StateBroadcastCommit, - expectedStartup: false, + broadcastErr: fmt.Errorf("Reject Commitment Tx"), + expectedState: StateBroadcastCommit, + // We should still be able to start up since we other + // channels might be closing as well and we should + // resolve the contracts. + expectedStartup: true, }, // We started after the DLP was triggered, and try to force diff --git a/docs/release-notes/release-notes-0.18.4.md b/docs/release-notes/release-notes-0.18.4.md index 56270c0a77..f13d57a87b 100644 --- a/docs/release-notes/release-notes-0.18.4.md +++ b/docs/release-notes/release-notes-0.18.4.md @@ -26,6 +26,13 @@ * [Make the contract resolutions for the channel arbitrator optional]( https://github.com/lightningnetwork/lnd/pull/9253). +* [Fixed a bug](https://github.com/lightningnetwork/lnd/pull/9324) to prevent + potential deadlocks when LND depends on external components (e.g. aux + components, hooks). + +* [Make sure blinded payment failures are handled correctly in the mission +controller](https://github.com/lightningnetwork/lnd/pull/9316). + # New Features The main channel state machine and database now allow for processing and storing @@ -121,4 +128,5 @@ types in a series of changes: * George Tsagkarelis * Olaoluwa Osuntokun * Oliver Gugger +* Ziggie diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 72143bc45a..124443c591 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -194,6 +194,11 @@ const ( Outgoing LinkDirection = true ) +// OptionalBandwidth is a type alias for the result of a bandwidth query that +// may return a bandwidth value or fn.None if the bandwidth is not available or +// not applicable. +type OptionalBandwidth = fn.Option[lnwire.MilliSatoshi] + // ChannelLink is an interface which represents the subsystem for managing the // incoming htlc requests, applying the changes to the channel, and also // propagating/forwarding it to htlc switch. @@ -255,10 +260,10 @@ type ChannelLink interface { // in order to signal to the source of the HTLC, the policy consistency // issue. CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi, - amtToForward lnwire.MilliSatoshi, - incomingTimeout, outgoingTimeout uint32, - inboundFee models.InboundFee, - heightNow uint32, scid lnwire.ShortChannelID) *LinkError + amtToForward lnwire.MilliSatoshi, incomingTimeout, + outgoingTimeout uint32, inboundFee models.InboundFee, + heightNow uint32, scid lnwire.ShortChannelID, + customRecords lnwire.CustomRecords) *LinkError // CheckHtlcTransit should return a nil error if the passed HTLC details // satisfy the current channel policy. Otherwise, a LinkError with a @@ -266,14 +271,15 @@ type ChannelLink interface { // the violation. This call is intended to be used for locally initiated // payments for which there is no corresponding incoming htlc. CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi, - timeout uint32, heightNow uint32) *LinkError + timeout uint32, heightNow uint32, + customRecords lnwire.CustomRecords) *LinkError // Stats return the statistics of channel link. Number of updates, // total sent/received milli-satoshis. Stats() (uint64, lnwire.MilliSatoshi, lnwire.MilliSatoshi) - // Peer returns the serialized public key of remote peer with which we - // have the channel link opened. + // PeerPubKey returns the serialized public key of remote peer with + // which we have the channel link opened. PeerPubKey() [33]byte // AttachMailBox delivers an active MailBox to the link. The MailBox may @@ -290,9 +296,18 @@ type ChannelLink interface { // commitment of the channel that this link is associated with. CommitmentCustomBlob() fn.Option[tlv.Blob] - // Start/Stop are used to initiate the start/stop of the channel link - // functioning. + // AuxBandwidth returns the bandwidth that can be used for a channel, + // expressed in milli-satoshi. This might be different from the regular + // BTC bandwidth for custom channels. This will always return fn.None() + // for a regular (non-custom) channel. + AuxBandwidth(amount lnwire.MilliSatoshi, cid lnwire.ShortChannelID, + htlcBlob fn.Option[tlv.Blob], + ts AuxTrafficShaper) fn.Result[OptionalBandwidth] + + // Start starts the channel link. Start() error + + // Stop requests the channel link to be shut down. Stop() } @@ -428,7 +443,7 @@ type htlcNotifier interface { NotifyForwardingEvent(key HtlcKey, info HtlcInfo, eventType HtlcEventType) - // NotifyIncomingLinkFailEvent notifies that a htlc has failed on our + // NotifyLinkFailEvent notifies that a htlc has failed on our // incoming link. It takes an isReceive bool to differentiate between // our node's receives and forwards. NotifyLinkFailEvent(key HtlcKey, info HtlcInfo, @@ -449,3 +464,36 @@ type htlcNotifier interface { NotifyFinalHtlcEvent(key models.CircuitKey, info channeldb.FinalHtlcInfo) } + +// AuxHtlcModifier is an interface that allows the sender to modify the outgoing +// HTLC of a payment by changing the amount or the wire message tlv records. +type AuxHtlcModifier interface { + // ProduceHtlcExtraData is a function that, based on the previous extra + // data blob of an HTLC, may produce a different blob or modify the + // amount of bitcoin this htlc should carry. + ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, + htlcCustomRecords lnwire.CustomRecords) (lnwire.MilliSatoshi, + lnwire.CustomRecords, error) +} + +// AuxTrafficShaper is an interface that allows the sender to determine if a +// payment should be carried by a channel based on the TLV records that may be +// present in the `update_add_htlc` message or the channel commitment itself. +type AuxTrafficShaper interface { + AuxHtlcModifier + + // ShouldHandleTraffic is called in order to check if the channel + // identified by the provided channel ID may have external mechanisms + // that would allow it to carry out the payment. + ShouldHandleTraffic(cid lnwire.ShortChannelID, + fundingBlob fn.Option[tlv.Blob]) (bool, error) + + // PaymentBandwidth returns the available bandwidth for a custom channel + // decided by the given channel aux blob and HTLC blob. A return value + // of 0 means there is no bandwidth available. To find out if a channel + // is a custom channel that should be handled by the traffic shaper, the + // ShouldHandleTraffic method should be called first. + PaymentBandwidth(htlcBlob, commitmentBlob fn.Option[tlv.Blob], + linkBandwidth, + htlcAmt lnwire.MilliSatoshi) (lnwire.MilliSatoshi, error) +} diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 2e2f104af7..008849dee8 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -284,6 +284,10 @@ type ChannelLinkConfig struct { // MaxFeeExposure is the threshold in milli-satoshis after which we'll // restrict the flow of HTLCs and fee updates. MaxFeeExposure lnwire.MilliSatoshi + + // AuxTrafficShaper is an optional auxiliary traffic shaper that can be + // used to manage the bandwidth of the link. + AuxTrafficShaper fn.Option[AuxTrafficShaper] } // channelLink is the service which drives a channel's commitment update @@ -3021,11 +3025,11 @@ func (l *channelLink) UpdateForwardingPolicy( // issue. // // NOTE: Part of the ChannelLink interface. -func (l *channelLink) CheckHtlcForward(payHash [32]byte, - incomingHtlcAmt, amtToForward lnwire.MilliSatoshi, - incomingTimeout, outgoingTimeout uint32, - inboundFee models.InboundFee, - heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError { +func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt, + amtToForward lnwire.MilliSatoshi, incomingTimeout, + outgoingTimeout uint32, inboundFee models.InboundFee, + heightNow uint32, originalScid lnwire.ShortChannelID, + customRecords lnwire.CustomRecords) *LinkError { l.RLock() policy := l.cfg.FwrdingPolicy @@ -3074,7 +3078,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // Check whether the outgoing htlc satisfies the channel policy. err := l.canSendHtlc( policy, payHash, amtToForward, outgoingTimeout, heightNow, - originalScid, + originalScid, customRecords, ) if err != nil { return err @@ -3110,8 +3114,8 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // the violation. This call is intended to be used for locally initiated // payments for which there is no corresponding incoming htlc. func (l *channelLink) CheckHtlcTransit(payHash [32]byte, - amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32) *LinkError { + amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32, + customRecords lnwire.CustomRecords) *LinkError { l.RLock() policy := l.cfg.FwrdingPolicy @@ -3122,6 +3126,7 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte, // to occur. return l.canSendHtlc( policy, payHash, amt, timeout, heightNow, hop.Source, + customRecords, ) } @@ -3129,7 +3134,8 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte, // the channel's amount and time lock constraints. func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError { + heightNow uint32, originalScid lnwire.ShortChannelID, + customRecords lnwire.CustomRecords) *LinkError { // As our first sanity check, we'll ensure that the passed HTLC isn't // too small for the next hop. If so, then we'll cancel the HTLC @@ -3187,8 +3193,38 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, return NewLinkError(&lnwire.FailExpiryTooFar{}) } + // We now check the available bandwidth to see if this HTLC can be + // forwarded. + availableBandwidth := l.Bandwidth() + auxBandwidth, err := fn.MapOptionZ( + l.cfg.AuxTrafficShaper, + func(ts AuxTrafficShaper) fn.Result[OptionalBandwidth] { + var htlcBlob fn.Option[tlv.Blob] + blob, err := customRecords.Serialize() + if err != nil { + return fn.Err[OptionalBandwidth]( + fmt.Errorf("unable to serialize "+ + "custom records: %w", err)) + } + + if len(blob) > 0 { + htlcBlob = fn.Some(blob) + } + + return l.AuxBandwidth(amt, originalScid, htlcBlob, ts) + }, + ).Unpack() + if err != nil { + l.log.Errorf("Unable to determine aux bandwidth: %v", err) + return NewLinkError(&lnwire.FailTemporaryNodeFailure{}) + } + + auxBandwidth.WhenSome(func(bandwidth lnwire.MilliSatoshi) { + availableBandwidth = bandwidth + }) + // Check to see if there is enough balance in this channel. - if amt > l.Bandwidth() { + if amt > availableBandwidth { l.log.Warnf("insufficient bandwidth to route htlc: %v is "+ "larger than %v", amt, l.Bandwidth()) cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { @@ -3203,6 +3239,48 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, return nil } +// AuxBandwidth returns the bandwidth that can be used for a channel, expressed +// in milli-satoshi. This might be different from the regular BTC bandwidth for +// custom channels. This will always return fn.None() for a regular (non-custom) +// channel. +func (l *channelLink) AuxBandwidth(amount lnwire.MilliSatoshi, + cid lnwire.ShortChannelID, htlcBlob fn.Option[tlv.Blob], + ts AuxTrafficShaper) fn.Result[OptionalBandwidth] { + + unknownBandwidth := fn.None[lnwire.MilliSatoshi]() + + fundingBlob := l.FundingCustomBlob() + shouldHandle, err := ts.ShouldHandleTraffic(cid, fundingBlob) + if err != nil { + return fn.Err[OptionalBandwidth](fmt.Errorf("traffic shaper "+ + "failed to decide whether to handle traffic: %w", err)) + } + + log.Debugf("ShortChannelID=%v: aux traffic shaper is handling "+ + "traffic: %v", cid, shouldHandle) + + // If this channel isn't handled by the aux traffic shaper, we'll return + // early. + if !shouldHandle { + return fn.Ok(unknownBandwidth) + } + + // Ask for a specific bandwidth to be used for the channel. + commitmentBlob := l.CommitmentCustomBlob() + auxBandwidth, err := ts.PaymentBandwidth( + htlcBlob, commitmentBlob, l.Bandwidth(), amount, + ) + if err != nil { + return fn.Err[OptionalBandwidth](fmt.Errorf("failed to get "+ + "bandwidth from external traffic shaper: %w", err)) + } + + log.Debugf("ShortChannelID=%v: aux traffic shaper reported available "+ + "bandwidth: %v", cid, auxBandwidth) + + return fn.Ok(fn.Some(auxBandwidth)) +} + // Stats returns the statistics of channel link. // // NOTE: Part of the ChannelLink interface. diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 574f3a6778..764df8c30f 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -6240,9 +6240,9 @@ func TestCheckHtlcForward(t *testing.T) { var hash [32]byte t.Run("satisfied", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if result != nil { t.Fatalf("expected policy to be satisfied") @@ -6250,9 +6250,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("below minhtlc", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 100, 50, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 100, 50, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailAmountBelowMinimum); !ok { t.Fatalf("expected FailAmountBelowMinimum failure code") @@ -6260,9 +6260,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("above maxhtlc", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1200, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1200, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailTemporaryChannelFailure); !ok { t.Fatalf("expected FailTemporaryChannelFailure failure code") @@ -6270,9 +6270,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("insufficient fee", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1005, 1000, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1005, 1000, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient); !ok { t.Fatalf("expected FailFeeInsufficient failure code") @@ -6285,17 +6285,17 @@ func TestCheckHtlcForward(t *testing.T) { t.Parallel() result := link.CheckHtlcForward( - hash, 100005, 100000, 200, - 150, models.InboundFee{}, 0, lnwire.ShortChannelID{}, + hash, 100005, 100000, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient) require.True(t, ok, "expected FailFeeInsufficient failure code") }) t.Run("expiry too soon", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 150, models.InboundFee{}, 190, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 200, 150, models.InboundFee{}, 190, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailExpiryTooSoon); !ok { t.Fatalf("expected FailExpiryTooSoon failure code") @@ -6303,9 +6303,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("incorrect cltv expiry", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 190, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 200, 190, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailIncorrectCltvExpiry); !ok { t.Fatalf("expected FailIncorrectCltvExpiry failure code") @@ -6315,9 +6315,9 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("cltv expiry too far in the future", func(t *testing.T) { // Check that expiry isn't too far in the future. - result := link.CheckHtlcForward(hash, 1500, 1000, - 10200, 10100, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 10200, 10100, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailExpiryTooFar); !ok { t.Fatalf("expected FailExpiryTooFar failure code") @@ -6327,9 +6327,11 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("inbound fee satisfied", func(t *testing.T) { t.Parallel() - result := link.CheckHtlcForward(hash, 1000+10-2-1, 1000, - 200, 150, models.InboundFee{Base: -2, Rate: -1_000}, - 0, lnwire.ShortChannelID{}) + result := link.CheckHtlcForward( + hash, 1000+10-2-1, 1000, 200, 150, + models.InboundFee{Base: -2, Rate: -1_000}, + 0, lnwire.ShortChannelID{}, nil, + ) if result != nil { t.Fatalf("expected policy to be satisfied") } @@ -6338,9 +6340,11 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("inbound fee insufficient", func(t *testing.T) { t.Parallel() - result := link.CheckHtlcForward(hash, 1000+10-10-101-1, 1000, + result := link.CheckHtlcForward( + hash, 1000+10-10-101-1, 1000, 200, 150, models.InboundFee{Base: -10, Rate: -100_000}, - 0, lnwire.ShortChannelID{}) + 0, lnwire.ShortChannelID{}, nil, + ) msg := result.WireMessage() if _, ok := msg.(*lnwire.FailFeeInsufficient); !ok { diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 750bdf784f..ab1f204894 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -845,14 +845,14 @@ func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) { } func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi, lnwire.MilliSatoshi, uint32, uint32, models.InboundFee, uint32, - lnwire.ShortChannelID) *LinkError { + lnwire.ShortChannelID, lnwire.CustomRecords) *LinkError { return f.checkHtlcForwardResult } func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32) *LinkError { + heightNow uint32, _ lnwire.CustomRecords) *LinkError { return f.checkHtlcTransitResult } @@ -959,6 +959,17 @@ func (f *mockChannelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] { return fn.None[tlv.Blob]() } +// AuxBandwidth returns the bandwidth that can be used for a channel, +// expressed in milli-satoshi. This might be different from the regular +// BTC bandwidth for custom channels. This will always return fn.None() +// for a regular (non-custom) channel. +func (f *mockChannelLink) AuxBandwidth(lnwire.MilliSatoshi, + lnwire.ShortChannelID, + fn.Option[tlv.Blob], AuxTrafficShaper) fn.Result[OptionalBandwidth] { + + return fn.Ok(fn.None[lnwire.MilliSatoshi]()) +} + var _ ChannelLink = (*mockChannelLink)(nil) func newDB() (*channeldb.DB, func(), error) { diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 35eb4a6ef4..b5a4ab5b7a 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -916,6 +916,7 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) ( currentHeight := atomic.LoadUint32(&s.bestHeight) htlcErr := link.CheckHtlcTransit( htlc.PaymentHash, htlc.Amount, htlc.Expiry, currentHeight, + htlc.CustomRecords, ) if htlcErr != nil { log.Errorf("Link %v policy for local forward not "+ @@ -2886,10 +2887,9 @@ func (s *Switch) handlePacketAdd(packet *htlcPacket, failure = link.CheckHtlcForward( htlc.PaymentHash, packet.incomingAmount, packet.amount, packet.incomingTimeout, - packet.outgoingTimeout, - packet.inboundFee, - currentHeight, - packet.originalOutgoingChanID, + packet.outgoingTimeout, packet.inboundFee, + currentHeight, packet.originalOutgoingChanID, + htlc.CustomRecords, ) } diff --git a/input/script_utils.go b/input/script_utils.go index 91ca55292f..a50cb932b6 100644 --- a/input/script_utils.go +++ b/input/script_utils.go @@ -29,9 +29,9 @@ var ( SequenceLockTimeSeconds = uint32(1 << 22) ) -// mustParsePubKey parses a hex encoded public key string into a public key and +// MustParsePubKey parses a hex encoded public key string into a public key and // panic if parsing fails. -func mustParsePubKey(pubStr string) btcec.PublicKey { +func MustParsePubKey(pubStr string) btcec.PublicKey { pubBytes, err := hex.DecodeString(pubStr) if err != nil { panic(err) @@ -55,7 +55,7 @@ var ( // https://github.com/lightninglabs/lightning-node-connect/tree/ // master/mailbox/numsgen, with the seed phrase "Lightning Simple // Taproot". - TaprootNUMSKey = mustParsePubKey(TaprootNUMSHex) + TaprootNUMSKey = MustParsePubKey(TaprootNUMSHex) ) // Signature is an interface for objects that can populate signatures during diff --git a/peer/brontide.go b/peer/brontide.go index fa42f13584..2d7540992c 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -400,6 +400,10 @@ type Config struct { // way contracts are resolved. AuxResolver fn.Option[lnwallet.AuxContractResolver] + // AuxTrafficShaper is an optional auxiliary traffic shaper that can be + // used to manage the bandwidth of peer links. + AuxTrafficShaper fn.Option[htlcswitch.AuxTrafficShaper] + // PongBuf is a slice we'll reuse instead of allocating memory on the // heap. Since only reads will occur and no writes, there is no need // for any synchronization primitives. As a result, it's safe to share @@ -1319,6 +1323,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, PreviouslySentShutdown: shutdownMsg, DisallowRouteBlinding: p.cfg.DisallowRouteBlinding, MaxFeeExposure: p.cfg.MaxFeeExposure, + AuxTrafficShaper: p.cfg.AuxTrafficShaper, } // Before adding our new link, purge the switch of any pending or live diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 3b80dadc7c..6ecd86765e 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -29,39 +29,6 @@ type bandwidthHints interface { firstHopCustomBlob() fn.Option[tlv.Blob] } -// TlvTrafficShaper is an interface that allows the sender to determine if a -// payment should be carried by a channel based on the TLV records that may be -// present in the `update_add_htlc` message or the channel commitment itself. -type TlvTrafficShaper interface { - AuxHtlcModifier - - // ShouldHandleTraffic is called in order to check if the channel - // identified by the provided channel ID may have external mechanisms - // that would allow it to carry out the payment. - ShouldHandleTraffic(cid lnwire.ShortChannelID, - fundingBlob fn.Option[tlv.Blob]) (bool, error) - - // PaymentBandwidth returns the available bandwidth for a custom channel - // decided by the given channel aux blob and HTLC blob. A return value - // of 0 means there is no bandwidth available. To find out if a channel - // is a custom channel that should be handled by the traffic shaper, the - // HandleTraffic method should be called first. - PaymentBandwidth(htlcBlob, commitmentBlob fn.Option[tlv.Blob], - linkBandwidth, - htlcAmt lnwire.MilliSatoshi) (lnwire.MilliSatoshi, error) -} - -// AuxHtlcModifier is an interface that allows the sender to modify the outgoing -// HTLC of a payment by changing the amount or the wire message tlv records. -type AuxHtlcModifier interface { - // ProduceHtlcExtraData is a function that, based on the previous extra - // data blob of an HTLC, may produce a different blob or modify the - // amount of bitcoin this htlc should carry. - ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, - htlcCustomRecords lnwire.CustomRecords) (lnwire.MilliSatoshi, - lnwire.CustomRecords, error) -} - // getLinkQuery is the function signature used to lookup a link. type getLinkQuery func(lnwire.ShortChannelID) ( htlcswitch.ChannelLink, error) @@ -73,7 +40,7 @@ type bandwidthManager struct { getLink getLinkQuery localChans map[lnwire.ShortChannelID]struct{} firstHopBlob fn.Option[tlv.Blob] - trafficShaper fn.Option[TlvTrafficShaper] + trafficShaper fn.Option[htlcswitch.AuxTrafficShaper] } // newBandwidthManager creates a bandwidth manager for the source node provided @@ -84,13 +51,14 @@ type bandwidthManager struct { // that are inactive, or just don't have enough bandwidth to carry the payment. func newBandwidthManager(graph Graph, sourceNode route.Vertex, linkQuery getLinkQuery, firstHopBlob fn.Option[tlv.Blob], - trafficShaper fn.Option[TlvTrafficShaper]) (*bandwidthManager, error) { + ts fn.Option[htlcswitch.AuxTrafficShaper]) (*bandwidthManager, + error) { manager := &bandwidthManager{ getLink: linkQuery, localChans: make(map[lnwire.ShortChannelID]struct{}), firstHopBlob: firstHopBlob, - trafficShaper: trafficShaper, + trafficShaper: ts, } // First, we'll collect the set of outbound edges from the target @@ -166,44 +134,15 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID, result, err := fn.MapOptionZ( b.trafficShaper, - func(ts TlvTrafficShaper) fn.Result[bandwidthResult] { - fundingBlob := link.FundingCustomBlob() - shouldHandle, err := ts.ShouldHandleTraffic( - cid, fundingBlob, - ) - if err != nil { - return bandwidthErr(fmt.Errorf("traffic "+ - "shaper failed to decide whether to "+ - "handle traffic: %w", err)) - } - - log.Debugf("ShortChannelID=%v: external traffic "+ - "shaper is handling traffic: %v", cid, - shouldHandle) - - // If this channel isn't handled by the external traffic - // shaper, we'll return early. - if !shouldHandle { - return fn.Ok(bandwidthResult{}) - } - - // Ask for a specific bandwidth to be used for the - // channel. - commitmentBlob := link.CommitmentCustomBlob() - auxBandwidth, err := ts.PaymentBandwidth( - b.firstHopBlob, commitmentBlob, linkBandwidth, - amount, - ) + func(s htlcswitch.AuxTrafficShaper) fn.Result[bandwidthResult] { + auxBandwidth, err := link.AuxBandwidth( + amount, cid, b.firstHopBlob, s, + ).Unpack() if err != nil { return bandwidthErr(fmt.Errorf("failed to get "+ - "bandwidth from external traffic "+ - "shaper: %w", err)) + "auxiliary bandwidth: %w", err)) } - log.Debugf("ShortChannelID=%v: external traffic "+ - "shaper reported available bandwidth: %v", cid, - auxBandwidth) - // We don't know the actual HTLC amount that will be // sent using the custom channel. But we'll still want // to make sure we can add another HTLC, using the @@ -213,7 +152,7 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID, // the max number of HTLCs on the channel. A proper // balance check is done elsewhere. return fn.Ok(bandwidthResult{ - bandwidth: fn.Some(auxBandwidth), + bandwidth: auxBandwidth, htlcAmount: fn.Some[lnwire.MilliSatoshi](0), }) }, diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go index 4872b5a7ec..28b1dfb1ab 100644 --- a/routing/bandwidth_test.go +++ b/routing/bandwidth_test.go @@ -118,7 +118,9 @@ func TestBandwidthManager(t *testing.T) { m, err := newBandwidthManager( g, sourceNode.pubkey, testCase.linkQuery, fn.None[[]byte](), - fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), ) require.NoError(t, err) diff --git a/routing/blinding.go b/routing/blinding.go index 270f998d9f..e08b7ea833 100644 --- a/routing/blinding.go +++ b/routing/blinding.go @@ -1,17 +1,27 @@ package routing import ( + "bytes" "errors" "fmt" "github.com/btcsuite/btcd/btcec/v2" + "github.com/decred/dcrd/dcrec/secp256k1/v4" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb/models" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) +// BlindedPathNUMSHex is the hex encoded version of the blinded path target +// NUMs key (in compressed format) which has no known private key. +// This was generated using the following script: +// https://github.com/lightninglabs/lightning-node-connect/tree/master/ +// mailbox/numsgen, with the seed phrase "Lightning Blinded Path". +const BlindedPathNUMSHex = "02667a98ef82ecb522f803b17a74f14508a48b25258f9831" + + "dd6e95f5e299dfd54e" + var ( // ErrNoBlindedPath is returned when the blinded path in a blinded // payment is missing. @@ -25,6 +35,14 @@ var ( // ErrHTLCRestrictions is returned when a blinded path has invalid // HTLC maximum and minimum values. ErrHTLCRestrictions = errors.New("invalid htlc minimum and maximum") + + // BlindedPathNUMSKey is a NUMS key (nothing up my sleeves number) that + // has no known private key. + BlindedPathNUMSKey = input.MustParsePubKey(BlindedPathNUMSHex) + + // CompressedBlindedPathNUMSKey is the compressed version of the + // BlindedPathNUMSKey. + CompressedBlindedPathNUMSKey = BlindedPathNUMSKey.SerializeCompressed() ) // BlindedPaymentPathSet groups the data we need to handle sending to a set of @@ -70,7 +88,9 @@ type BlindedPaymentPathSet struct { } // NewBlindedPaymentPathSet constructs a new BlindedPaymentPathSet from a set of -// BlindedPayments. +// BlindedPayments. For blinded paths which have more than one single hop a +// dummy hop via a NUMS key is appeneded to allow for MPP path finding via +// multiple blinded paths. func NewBlindedPaymentPathSet(paths []*BlindedPayment) (*BlindedPaymentPathSet, error) { @@ -103,36 +123,53 @@ func NewBlindedPaymentPathSet(paths []*BlindedPayment) (*BlindedPaymentPathSet, } } - // Derive an ephemeral target priv key that will be injected into each - // blinded path final hop. - targetPriv, err := btcec.NewPrivateKey() - if err != nil { - return nil, err + // Deep copy the paths to avoid mutating the original paths. + pathSet := make([]*BlindedPayment, len(paths)) + for i, path := range paths { + pathSet[i] = path.deepCopy() } - targetPub := targetPriv.PubKey() - var ( - pathSet = paths - finalCLTVDelta uint16 - ) - // If any provided blinded path only has a single hop (ie, the - // destination node is also the introduction node), then we discard all - // other paths since we know the real pub key of the destination node. - // We also then set the final CLTV delta to the path's delta since - // there are no other edge hints that will account for it. For a single - // hop path, there is also no need for the pseudo target pub key - // replacement, so our target pub key in this case just remains the - // real introduction node ID. - for _, path := range paths { - if len(path.BlindedPath.BlindedHops) != 1 { - continue + // For blinded paths we use the NUMS key as a target if the blinded + // path has more hops than just the introduction node. + targetPub := &BlindedPathNUMSKey + + var finalCLTVDelta uint16 + + // In case the paths do NOT include a single hop route we append a + // dummy hop via a NUMS key to allow for MPP path finding via multiple + // blinded paths. A unified target is needed to use all blinded paths + // during the payment lifecycle. A dummy hop is solely added for the + // path finding process and is removed after the path is found. This + // ensures that we still populate the mission control with the correct + // data and also respect these mc entries when looking for a path. + for _, path := range pathSet { + pathLength := len(path.BlindedPath.BlindedHops) + + // If any provided blinded path only has a single hop (ie, the + // destination node is also the introduction node), then we + // discard all other paths since we know the real pub key of the + // destination node. We also then set the final CLTV delta to + // the path's delta since there are no other edge hints that + // will account for it. + if pathLength == 1 { + pathSet = []*BlindedPayment{path} + finalCLTVDelta = path.CltvExpiryDelta + targetPub = path.BlindedPath.IntroductionPoint + + break } - pathSet = []*BlindedPayment{path} - finalCLTVDelta = path.CltvExpiryDelta - targetPub = path.BlindedPath.IntroductionPoint - - break + lastHop := path.BlindedPath.BlindedHops[pathLength-1] + path.BlindedPath.BlindedHops = append( + path.BlindedPath.BlindedHops, + &sphinx.BlindedHopInfo{ + BlindedNodePub: &BlindedPathNUMSKey, + // We add the last hop's cipher text so that + // the payload size of the final hop is equal + // to the real last hop. + CipherText: lastHop.CipherText, + }, + ) } return &BlindedPaymentPathSet{ @@ -198,21 +235,33 @@ func (s *BlindedPaymentPathSet) FinalCLTVDelta() uint16 { // LargestLastHopPayloadPath returns the BlindedPayment in the set that has the // largest last-hop payload. This is to be used for onion size estimation in // path finding. -func (s *BlindedPaymentPathSet) LargestLastHopPayloadPath() *BlindedPayment { +func (s *BlindedPaymentPathSet) LargestLastHopPayloadPath() (*BlindedPayment, + error) { + var ( largestPath *BlindedPayment currentMax int ) + + if len(s.paths) == 0 { + return nil, fmt.Errorf("no blinded paths in the set") + } + + // We set the largest path to make sure we always return a path even + // if the cipher text is empty. + largestPath = s.paths[0] + for _, path := range s.paths { numHops := len(path.BlindedPath.BlindedHops) lastHop := path.BlindedPath.BlindedHops[numHops-1] if len(lastHop.CipherText) > currentMax { largestPath = path + currentMax = len(lastHop.CipherText) } } - return largestPath + return largestPath, nil } // ToRouteHints converts the blinded path payment set into a RouteHints map so @@ -222,7 +271,7 @@ func (s *BlindedPaymentPathSet) ToRouteHints() (RouteHints, error) { hints := make(RouteHints) for _, path := range s.paths { - pathHints, err := path.toRouteHints(fn.Some(s.targetPubKey)) + pathHints, err := path.toRouteHints() if err != nil { return nil, err } @@ -239,6 +288,12 @@ func (s *BlindedPaymentPathSet) ToRouteHints() (RouteHints, error) { return hints, nil } +// IsBlindedRouteNUMSTargetKey returns true if the given public key is the +// NUMS key used as a target for blinded path final hops. +func IsBlindedRouteNUMSTargetKey(pk []byte) bool { + return bytes.Equal(pk, CompressedBlindedPathNUMSKey) +} + // BlindedPayment provides the path and payment parameters required to send a // payment along a blinded path. type BlindedPayment struct { @@ -291,6 +346,22 @@ func (b *BlindedPayment) Validate() error { b.HtlcMaximum, b.HtlcMinimum) } + for _, hop := range b.BlindedPath.BlindedHops { + // The first hop of the blinded path does not necessarily have + // blinded node pub key because it is the introduction point. + if hop.BlindedNodePub == nil { + continue + } + + if IsBlindedRouteNUMSTargetKey( + hop.BlindedNodePub.SerializeCompressed(), + ) { + + return fmt.Errorf("blinded path cannot include NUMS "+ + "key: %s", BlindedPathNUMSHex) + } + } + return nil } @@ -301,11 +372,8 @@ func (b *BlindedPayment) Validate() error { // effectively the final_cltv_delta for the receiving introduction node). In // the case of multiple blinded hops, CLTV delta is fully accounted for in the // hints (both for intermediate hops and the final_cltv_delta for the receiving -// node). The pseudoTarget, if provided, will be used to override the pub key -// of the destination node in the path. -func (b *BlindedPayment) toRouteHints( - pseudoTarget fn.Option[*btcec.PublicKey]) (RouteHints, error) { - +// node). +func (b *BlindedPayment) toRouteHints() (RouteHints, error) { // If we just have a single hop in our blinded route, it just contains // an introduction node (this is a valid path according to the spec). // Since we have the un-blinded node ID for the introduction node, we @@ -393,16 +461,77 @@ func (b *BlindedPayment) toRouteHints( hints[fromNode] = []AdditionalEdge{lastEdge} } - pseudoTarget.WhenSome(func(key *btcec.PublicKey) { - // For the very last hop on the path, switch out the ToNodePub - // for the pseudo target pub key. - lastEdge.policy.ToNodePubKey = func() route.Vertex { - return route.NewVertex(key) + return hints, nil +} + +// deepCopy returns a deep copy of the BlindedPayment. +func (b *BlindedPayment) deepCopy() *BlindedPayment { + if b == nil { + return nil + } + + cpyPayment := &BlindedPayment{ + BaseFee: b.BaseFee, + ProportionalFeeRate: b.ProportionalFeeRate, + CltvExpiryDelta: b.CltvExpiryDelta, + HtlcMinimum: b.HtlcMinimum, + HtlcMaximum: b.HtlcMaximum, + } + + // Deep copy the BlindedPath if it exists + if b.BlindedPath != nil { + cpyPayment.BlindedPath = &sphinx.BlindedPath{ + BlindedHops: make([]*sphinx.BlindedHopInfo, + len(b.BlindedPath.BlindedHops)), } - // Then override the final hint with this updated edge. - hints[fromNode] = []AdditionalEdge{lastEdge} - }) + if b.BlindedPath.IntroductionPoint != nil { + cpyPayment.BlindedPath.IntroductionPoint = + copyPublicKey(b.BlindedPath.IntroductionPoint) + } - return hints, nil + if b.BlindedPath.BlindingPoint != nil { + cpyPayment.BlindedPath.BlindingPoint = + copyPublicKey(b.BlindedPath.BlindingPoint) + } + + // Copy each blinded hop info. + for i, hop := range b.BlindedPath.BlindedHops { + if hop == nil { + continue + } + + cpyHop := &sphinx.BlindedHopInfo{ + CipherText: hop.CipherText, + } + + if hop.BlindedNodePub != nil { + cpyHop.BlindedNodePub = + copyPublicKey(hop.BlindedNodePub) + } + + cpyHop.CipherText = make([]byte, len(hop.CipherText)) + copy(cpyHop.CipherText, hop.CipherText) + + cpyPayment.BlindedPath.BlindedHops[i] = cpyHop + } + } + + // Deep copy the Features if they exist + if b.Features != nil { + cpyPayment.Features = b.Features.Clone() + } + + return cpyPayment +} + +// copyPublicKey makes a deep copy of a public key. +// +// TODO(ziggie): Remove this function if this is available in the btcec library. +func copyPublicKey(pk *btcec.PublicKey) *btcec.PublicKey { + var result secp256k1.JacobianPoint + pk.AsJacobian(&result) + result.ToAffine() + + return btcec.NewPublicKey(&result.X, &result.Y) } diff --git a/routing/blinding_test.go b/routing/blinding_test.go index 950cb02107..1fabc10c22 100644 --- a/routing/blinding_test.go +++ b/routing/blinding_test.go @@ -2,12 +2,12 @@ package routing import ( "bytes" + "reflect" "testing" "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb/models" - "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" @@ -129,7 +129,7 @@ func TestBlindedPaymentToHints(t *testing.T) { HtlcMaximum: htlcMax, Features: features, } - hints, err := blindedPayment.toRouteHints(fn.None[*btcec.PublicKey]()) + hints, err := blindedPayment.toRouteHints() require.NoError(t, err) require.Nil(t, hints) @@ -184,7 +184,7 @@ func TestBlindedPaymentToHints(t *testing.T) { }, } - actual, err := blindedPayment.toRouteHints(fn.None[*btcec.PublicKey]()) + actual, err := blindedPayment.toRouteHints() require.NoError(t, err) require.Equal(t, len(expected), len(actual)) @@ -218,3 +218,63 @@ func TestBlindedPaymentToHints(t *testing.T) { require.Equal(t, expectedHint[0], actualHint[0]) } } + +// TestBlindedPaymentDeepCopy tests the deep copy method of the BLindedPayment +// struct. +// +// TODO(ziggie): Make this a property test instead. +func TestBlindedPaymentDeepCopy(t *testing.T) { + _, pkBlind1 := btcec.PrivKeyFromBytes([]byte{1}) + _, blindingPoint := btcec.PrivKeyFromBytes([]byte{2}) + _, pkBlind2 := btcec.PrivKeyFromBytes([]byte{3}) + + // Create a test BlindedPayment with non-nil fields + original := &BlindedPayment{ + BaseFee: 1000, + ProportionalFeeRate: 2000, + CltvExpiryDelta: 144, + HtlcMinimum: 1000, + HtlcMaximum: 1000000, + Features: lnwire.NewFeatureVector(nil, nil), + BlindedPath: &sphinx.BlindedPath{ + IntroductionPoint: pkBlind1, + BlindingPoint: blindingPoint, + BlindedHops: []*sphinx.BlindedHopInfo{ + { + BlindedNodePub: pkBlind2, + CipherText: []byte("test cipher"), + }, + }, + }, + } + + // Make a deep copy + cpyPayment := original.deepCopy() + + // Test 1: Verify the copy is not the same pointer + if cpyPayment == original { + t.Fatal("deepCopy returned same pointer") + } + + // Verify all fields are equal + if !reflect.DeepEqual(original, cpyPayment) { + t.Fatal("copy is not equal to original") + } + + // Modify the copy and verify it doesn't affect the original + cpyPayment.BaseFee = 2000 + cpyPayment.BlindedPath.BlindedHops[0].CipherText = []byte("modified") + + require.NotEqual(t, original.BaseFee, cpyPayment.BaseFee) + + require.NotEqual( + t, + original.BlindedPath.BlindedHops[0].CipherText, + cpyPayment.BlindedPath.BlindedHops[0].CipherText, + ) + + // Verify nil handling. + var nilPayment *BlindedPayment + nilCopy := nilPayment.deepCopy() + require.Nil(t, nilCopy) +} diff --git a/routing/mock_test.go b/routing/mock_test.go index 99d56c68bd..f604b777a4 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -107,7 +107,7 @@ var _ PaymentSessionSource = (*mockPaymentSessionSourceOld)(nil) func (m *mockPaymentSessionSourceOld) NewPaymentSession( _ *LightningPayment, _ fn.Option[tlv.Blob], - _ fn.Option[TlvTrafficShaper]) (PaymentSession, error) { + _ fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, error) { return &mockPaymentSessionOld{ routes: m.routes, @@ -635,7 +635,8 @@ var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil) func (m *mockPaymentSessionSource) NewPaymentSession( payment *LightningPayment, firstHopBlob fn.Option[tlv.Blob], - tlvShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { + tlvShaper fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, + error) { args := m.Called(payment, firstHopBlob, tlvShaper) return args.Get(0).(PaymentSession), args.Error(1) @@ -895,6 +896,19 @@ func (m *mockLink) Bandwidth() lnwire.MilliSatoshi { return m.bandwidth } +// AuxBandwidth returns the bandwidth that can be used for a channel, +// expressed in milli-satoshi. This might be different from the regular +// BTC bandwidth for custom channels. This will always return fn.None() +// for a regular (non-custom) channel. +func (m *mockLink) AuxBandwidth(lnwire.MilliSatoshi, lnwire.ShortChannelID, + fn.Option[tlv.Blob], + htlcswitch.AuxTrafficShaper) fn.Result[htlcswitch.OptionalBandwidth] { + + return fn.Ok[htlcswitch.OptionalBandwidth]( + fn.None[lnwire.MilliSatoshi](), + ) +} + // EligibleToForward returns the mock's configured eligibility. func (m *mockLink) EligibleToForward() bool { return !m.ineligible diff --git a/routing/pathfind.go b/routing/pathfind.go index 43eae71036..300613b1fe 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -158,6 +158,32 @@ func newRoute(sourceVertex route.Vertex, ) pathLength := len(pathEdges) + + // When paying to a blinded route we might have appended a dummy hop at + // the end to make MPP payments possible via all paths of the blinded + // route set. We always append a dummy hop when the internal pathfiner + // looks for a route to a blinded path which is at least one hop long + // (excluding the introduction point). We add this dummy hop so that + // we search for a universal target but also respect potential mc + // entries which might already be present for a particular blinded path. + // However when constructing the Sphinx packet we need to remove this + // dummy hop again which we do here. + // + // NOTE: The path length is always at least 1 because there must be one + // edge from the source to the destination. However we check for > 0 + // just for robustness here. + if blindedPathSet != nil && pathLength > 0 { + finalBlindedPubKey := pathEdges[pathLength-1].policy. + ToNodePubKey() + + if IsBlindedRouteNUMSTargetKey(finalBlindedPubKey[:]) { + // If the last hop is the NUMS key for blinded paths, we + // remove the dummy hop from the route. + pathEdges = pathEdges[:pathLength-1] + pathLength-- + } + } + for i := pathLength - 1; i >= 0; i-- { // Now we'll start to calculate the items within the per-hop // payload for the hop this edge is leading to. @@ -319,10 +345,6 @@ func newRoute(sourceVertex route.Vertex, dataIndex = 0 blindedPath = blindedPayment.BlindedPath - numHops = len(blindedPath.BlindedHops) - realFinal = blindedPath.BlindedHops[numHops-1]. - BlindedNodePub - introVertex = route.NewVertex( blindedPath.IntroductionPoint, ) @@ -350,11 +372,6 @@ func newRoute(sourceVertex route.Vertex, if i != len(hops)-1 { hop.AmtToForward = 0 hop.OutgoingTimeLock = 0 - } else { - // For the final hop, we swap out the pub key - // bytes to the original destination node pub - // key for that payment path. - hop.PubKeyBytes = route.NewVertex(realFinal) } dataIndex++ @@ -683,7 +700,10 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // The payload size of the final hop differ from intermediate hops // and depends on whether the destination is blinded or not. - lastHopPayloadSize := lastHopPayloadSize(r, finalHtlcExpiry, amt) + lastHopPayloadSize, err := lastHopPayloadSize(r, finalHtlcExpiry, amt) + if err != nil { + return nil, 0, err + } // We can't always assume that the end destination is publicly // advertised to the network so we'll manually include the target node. @@ -901,6 +921,13 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // included. If we are coming from the source hop, the payload // size is zero, because the original htlc isn't in the onion // blob. + // + // NOTE: For blinded paths with the NUMS key as the last hop, + // the payload size accounts for this dummy hop which is of + // the same size as the real last hop. So we account for a + // bigger size than the route is however we accept this + // little inaccuracy here because we are over estimating by + // 1 hop. var payloadSize uint64 if fromVertex != source { // In case the unifiedEdge does not have a payload size @@ -1409,11 +1436,15 @@ func getProbabilityBasedDist(weight int64, probability float64, // It depends on the tlv types which are present and also whether the hop is // part of a blinded route or not. func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32, - amount lnwire.MilliSatoshi) uint64 { + amount lnwire.MilliSatoshi) (uint64, error) { if r.BlindedPaymentPathSet != nil { - paymentPath := r.BlindedPaymentPathSet. + paymentPath, err := r.BlindedPaymentPathSet. LargestLastHopPayloadPath() + if err != nil { + return 0, err + } + blindedPath := paymentPath.BlindedPath.BlindedHops blindedPoint := paymentPath.BlindedPath.BlindingPoint @@ -1428,7 +1459,7 @@ func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32, } // The final hop does not have a short chanID set. - return finalHop.PayloadSize(0) + return finalHop.PayloadSize(0), nil } var mpp *record.MPP @@ -1454,7 +1485,7 @@ func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32, } // The final hop does not have a short chanID set. - return finalHop.PayloadSize(0) + return finalHop.PayloadSize(0), nil } // overflowSafeAdd adds two MilliSatoshi values and returns the result. If an diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 72f71600dd..964b0b88a8 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -768,6 +768,9 @@ func TestPathFinding(t *testing.T) { }, { name: "path finding with additional edges", fn: runPathFindingWithAdditionalEdges, + }, { + name: "path finding with duplicate blinded hop", + fn: runPathFindingWithBlindedPathDuplicateHop, }, { name: "path finding with redundant additional edges", fn: runPathFindingWithRedundantAdditionalEdges, @@ -1268,6 +1271,107 @@ func runPathFindingWithAdditionalEdges(t *testing.T, useCache bool) { assertExpectedPath(t, graph.aliasMap, path, "songoku", "doge") } +// runPathFindingWithBlindedPathDuplicateHop tests that in case a blinded path +// has duplicate hops that the path finding algorithm does not fail or behave +// incorrectly. This can happen because the creator of the blinded path can +// specify the same hop multiple times and this will only be detected at the +// forwarding nodes, so it is important that we can handle this case. +func runPathFindingWithBlindedPathDuplicateHop(t *testing.T, useCache bool) { + graph, err := parseTestGraph(t, useCache, basicGraphFilePath) + require.NoError(t, err, "unable to create graph") + + sourceNode, err := graph.graph.SourceNode() + require.NoError(t, err, "unable to fetch source node") + + paymentAmt := lnwire.NewMSatFromSatoshis(100) + + songokuPubKeyBytes := graph.aliasMap["songoku"] + songokuPubKey, err := btcec.ParsePubKey(songokuPubKeyBytes[:]) + require.NoError(t, err, "unable to parse public key from bytes") + + _, pkb1 := btcec.PrivKeyFromBytes([]byte{2}) + _, pkb2 := btcec.PrivKeyFromBytes([]byte{3}) + _, blindedPoint := btcec.PrivKeyFromBytes([]byte{5}) + + sizeEncryptedData := 100 + cipherText := bytes.Repeat( + []byte{1}, sizeEncryptedData, + ) + + vb1 := route.NewVertex(pkb1) + vb2 := route.NewVertex(pkb2) + + // Payments to blinded paths always pay to the NUMS target key. + dummyTarget := route.NewVertex(&BlindedPathNUMSKey) + + graph.aliasMap["pkb1"] = vb1 + graph.aliasMap["pkb2"] = vb2 + graph.aliasMap["dummyTarget"] = dummyTarget + + // Create a blinded payment with duplicate hops and make sure the + // path finding algorithm can cope with that. We add blinded hop 2 + // 3 times. The path finding algorithm should create a path with a + // single hop to pkb2 (the first entry). + blindedPayment := &BlindedPayment{ + BlindedPath: &sphinx.BlindedPath{ + IntroductionPoint: songokuPubKey, + BlindingPoint: blindedPoint, + BlindedHops: []*sphinx.BlindedHopInfo{ + { + CipherText: cipherText, + }, + { + BlindedNodePub: pkb2, + CipherText: cipherText, + }, + { + BlindedNodePub: pkb1, + CipherText: cipherText, + }, + { + BlindedNodePub: pkb2, + CipherText: cipherText, + }, + { + BlindedNodePub: &BlindedPathNUMSKey, + CipherText: cipherText, + }, + { + BlindedNodePub: pkb2, + CipherText: cipherText, + }, + }, + }, + HtlcMinimum: 1, + HtlcMaximum: 100_000_000, + CltvExpiryDelta: 140, + } + + blindedPath, err := blindedPayment.toRouteHints() + require.NoError(t, err) + + find := func(r *RestrictParams) ( + []*unifiedEdge, error) { + + return dbFindPath( + graph.graph, blindedPath, &mockBandwidthHints{}, + r, testPathFindingConfig, + sourceNode.PubKeyBytes, dummyTarget, paymentAmt, + 0, 0, + ) + } + + // We should now be able to find a path however not the chained path + // of the blinded hops. + path, err := find(noRestrictions) + require.NoError(t, err, "unable to create route to blinded path") + + // The path should represent the following hops: + // source node -> songoku -> pkb2 -> dummyTarget + assertExpectedPath(t, graph.aliasMap, path, "songoku", "pkb2", + "dummyTarget") +} + // runPathFindingWithRedundantAdditionalEdges asserts that we are able to find // paths to nodes ignoring additional edges that are already known by self node. func runPathFindingWithRedundantAdditionalEdges(t *testing.T, useCache bool) { @@ -3287,9 +3391,7 @@ func TestBlindedRouteConstruction(t *testing.T) { // that make up the graph we'll give to route construction. The hints // map is keyed by source node, so we can retrieve our blinded edges // accordingly. - blindedEdges, err := blindedPayment.toRouteHints( - fn.None[*btcec.PublicKey](), - ) + blindedEdges, err := blindedPayment.toRouteHints() require.NoError(t, err) carolDaveEdge := blindedEdges[carolVertex][0] @@ -3418,32 +3520,48 @@ func TestLastHopPayloadSize(t *testing.T) { customRecords = map[uint64][]byte{ record.CustomTypeStart: {1, 2, 3}, } - sizeEncryptedData = 100 - encrypedData = bytes.Repeat( - []byte{1}, sizeEncryptedData, + + encrypedDataSmall = bytes.Repeat( + []byte{1}, 5, ) - _, blindedPoint = btcec.PrivKeyFromBytes([]byte{5}) - paymentAddr = &[32]byte{1} - ampOptions = &Options{} - amtToForward = lnwire.MilliSatoshi(10000) - finalHopExpiry int32 = 144 + encrypedDataLarge = bytes.Repeat( + []byte{1}, 100, + ) + _, blindedPoint = btcec.PrivKeyFromBytes([]byte{5}) + paymentAddr = &[32]byte{1} + ampOptions = &Options{} + amtToForward = lnwire.MilliSatoshi(10000) + emptyEncryptedData = []byte{} + finalHopExpiry int32 = 144 oneHopPath = &sphinx.BlindedPath{ BlindedHops: []*sphinx.BlindedHopInfo{ { - CipherText: encrypedData, + CipherText: emptyEncryptedData, + }, + }, + BlindingPoint: blindedPoint, + } + + twoHopPathSmallHopSize = &sphinx.BlindedPath{ + BlindedHops: []*sphinx.BlindedHopInfo{ + { + CipherText: encrypedDataLarge, + }, + { + CipherText: encrypedDataLarge, }, }, BlindingPoint: blindedPoint, } - twoHopPath = &sphinx.BlindedPath{ + twoHopPathLargeHopSize = &sphinx.BlindedPath{ BlindedHops: []*sphinx.BlindedHopInfo{ { - CipherText: encrypedData, + CipherText: encrypedDataSmall, }, { - CipherText: encrypedData, + CipherText: encrypedDataSmall, }, }, BlindingPoint: blindedPoint, @@ -3456,15 +3574,19 @@ func TestLastHopPayloadSize(t *testing.T) { require.NoError(t, err) twoHopBlindedPayment, err := NewBlindedPaymentPathSet( - []*BlindedPayment{{BlindedPath: twoHopPath}}, + []*BlindedPayment{ + {BlindedPath: twoHopPathLargeHopSize}, + {BlindedPath: twoHopPathSmallHopSize}, + }, ) require.NoError(t, err) testCases := []struct { - name string - restrictions *RestrictParams - finalHopExpiry int32 - amount lnwire.MilliSatoshi + name string + restrictions *RestrictParams + finalHopExpiry int32 + amount lnwire.MilliSatoshi + expectedEncryptedData []byte }{ { name: "Non blinded final hop", @@ -3482,16 +3604,18 @@ func TestLastHopPayloadSize(t *testing.T) { restrictions: &RestrictParams{ BlindedPaymentPathSet: oneHopBlindedPayment, }, - amount: amtToForward, - finalHopExpiry: finalHopExpiry, + amount: amtToForward, + finalHopExpiry: finalHopExpiry, + expectedEncryptedData: emptyEncryptedData, }, { name: "Blinded final hop of a two hop payment", restrictions: &RestrictParams{ BlindedPaymentPathSet: twoHopBlindedPayment, }, - amount: amtToForward, - finalHopExpiry: finalHopExpiry, + amount: amtToForward, + finalHopExpiry: finalHopExpiry, + expectedEncryptedData: encrypedDataLarge, }, } @@ -3515,16 +3639,23 @@ func TestLastHopPayloadSize(t *testing.T) { var finalHop route.Hop if tc.restrictions.BlindedPaymentPathSet != nil { - path := tc.restrictions.BlindedPaymentPathSet. - LargestLastHopPayloadPath() + bPSet := tc.restrictions.BlindedPaymentPathSet + path, err := bPSet.LargestLastHopPayloadPath() + require.NotNil(t, path) + + require.NoError(t, err) + blindedPath := path.BlindedPath.BlindedHops blindedPoint := path.BlindedPath.BlindingPoint + lastHop := blindedPath[len(blindedPath)-1] + require.Equal(t, lastHop.CipherText, + tc.expectedEncryptedData) //nolint:lll finalHop = route.Hop{ AmtToForward: tc.amount, OutgoingTimeLock: uint32(tc.finalHopExpiry), - EncryptedData: blindedPath[len(blindedPath)-1].CipherText, + EncryptedData: lastHop.CipherText, } if len(blindedPath) == 1 { finalHop.BlindingPoint = blindedPoint @@ -3544,11 +3675,11 @@ func TestLastHopPayloadSize(t *testing.T) { payLoad, err := createHopPayload(finalHop, 0, true) require.NoErrorf(t, err, "failed to create hop payload") - expectedPayloadSize := lastHopPayloadSize( + expectedPayloadSize, err := lastHopPayloadSize( tc.restrictions, tc.finalHopExpiry, tc.amount, ) - + require.NoError(t, err) require.Equal( t, expectedPayloadSize, uint64(payLoad.NumBytes()), diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 43e646e192..5e72beb87d 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -761,7 +761,8 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error { // and apply its side effects to the UpdateAddHTLC message. result, err := fn.MapOptionZ( p.router.cfg.TrafficShaper, - func(ts TlvTrafficShaper) fn.Result[extraDataRequest] { + //nolint:ll + func(ts htlcswitch.AuxTrafficShaper) fn.Result[extraDataRequest] { newAmt, newRecords, err := ts.ProduceHtlcExtraData( rt.TotalAmount, p.firstHopCustomRecords, ) @@ -774,7 +775,7 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error { return fn.Err[extraDataRequest](err) } - log.Debugf("TLV traffic shaper returned custom "+ + log.Debugf("Aux traffic shaper returned custom "+ "records %v and amount %d msat for HTLC", spew.Sdump(newRecords), newAmt) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 315c1bad58..d566eb9413 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -30,7 +30,7 @@ func createTestPaymentLifecycle() *paymentLifecycle { quitChan := make(chan struct{}) rt := &ChannelRouter{ cfg: &Config{ - TrafficShaper: fn.Some[TlvTrafficShaper]( + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, ), }, @@ -83,7 +83,7 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { Payer: mockPayer, Clock: mockClock, MissionControl: mockMissionControl, - TrafficShaper: fn.Some[TlvTrafficShaper]( + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, ), }, diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index c89d6a8e52..b03f50b153 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -5,6 +5,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" @@ -53,7 +54,8 @@ type SessionSource struct { // payment's destination. func (m *SessionSource) NewPaymentSession(p *LightningPayment, firstHopBlob fn.Option[tlv.Blob], - trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { + trafficShaper fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, + error) { getBandwidthHints := func(graph Graph) (bandwidthHints, error) { return newBandwidthManager( diff --git a/routing/result_interpretation.go b/routing/result_interpretation.go index 4118286e64..b998968d01 100644 --- a/routing/result_interpretation.go +++ b/routing/result_interpretation.go @@ -93,8 +93,27 @@ func interpretResult(rt *route.Route, success bool, failureSrcIdx *int, // processSuccess processes a successful payment attempt. func (i *interpretedResult) processSuccess(route *route.Route) { - // For successes, all nodes must have acted in the right way. Therefore - // we mark all of them with a success result. + // For successes, all nodes must have acted in the right way. + // Therefore we mark all of them with a success result. However we need + // to handle the blinded route part separately because for intermediate + // blinded nodes the amount field is set to zero so we use the receiver + // amount. + introIdx, isBlinded := introductionPointIndex(route) + if isBlinded { + // Report success for all the pairs until the introduction + // point. + i.successPairRange(route, 0, introIdx-1) + + // Handle the blinded route part. + // + // NOTE: The introIdx index here does describe the node after + // the introduction point. + i.markBlindedRouteSuccess(route, introIdx) + + return + } + + // Mark nodes as successful in the non-blinded case of the payment. i.successPairRange(route, 0, len(route.Hops)-1) } @@ -505,13 +524,22 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate( if introIdx == len(route.Hops)-1 { i.finalFailureReason = &reasonError } else { - // If there are other hops between the recipient and - // introduction node, then we just penalize the last - // hop in the blinded route to minimize the storage of - // results for ephemeral keys. - i.failPairBalance( - route, len(route.Hops)-1, - ) + // We penalize the final hop of the blinded route which + // is sufficient to not reuse this route again and is + // also more memory efficient because the other hops + // of the blinded path are ephemeral and will only be + // used in conjunction with the final hop. Moreover we + // don't want to punish the introduction node because + // the blinded failure does not necessarily mean that + // the introduction node was at fault. + // + // TODO(ziggie): Make sure we only keep mc data for + // blinded paths, in both the success and failure case, + // in memory during the time of the payment and remove + // it afterwards. Blinded paths and their blinded hop + // keys are always changing per blinded route so there + // is no point in persisting this data. + i.failBlindedRoute(route) } // In all other cases, we penalize the reporting node. These are all @@ -624,6 +652,43 @@ func (i *interpretedResult) successPairRange( } } +// failBlindedRoute marks a blinded route as failed for the specific amount to +// send by only punishing the last pair. +func (i *interpretedResult) failBlindedRoute(rt *route.Route) { + // We fail the last pair of the route, in order to fail the complete + // blinded route. This is because the combination of ephemeral pubkeys + // is unique to the route. We fail the last pair in order to not punish + // the introduction node, since we don't want to disincentivize them + // from providing that service. + pair, _ := getPair(rt, len(rt.Hops)-1) + + // Since all the hops along a blinded path don't have any amount set, we + // extract the minimal amount to punish from the value that is tried to + // be sent to the receiver. + amt := rt.Hops[len(rt.Hops)-1].AmtToForward + + i.pairResults[pair] = failPairResult(amt) +} + +// markBlindedRouteSuccess marks the hops of the blinded route AFTER the +// introduction node as successful. +// +// NOTE: The introIdx must be the index of the first hop of the blinded route +// AFTER the introduction node. +func (i *interpretedResult) markBlindedRouteSuccess(rt *route.Route, + introIdx int) { + + // For blinded hops we do not have the forwarding amount so we take the + // minimal amount which went through the route by looking at the last + // hop. + successAmt := rt.Hops[len(rt.Hops)-1].AmtToForward + for idx := introIdx; idx < len(rt.Hops); idx++ { + pair, _ := getPair(rt, idx) + + i.pairResults[pair] = successPairResult(successAmt) + } +} + // getPair returns a node pair from the route and the amount passed between that // pair. func getPair(rt *route.Route, channelIdx int) (DirectedNodePair, diff --git a/routing/result_interpretation_test.go b/routing/result_interpretation_test.go index bf7d6d3edd..0b8a2e2629 100644 --- a/routing/result_interpretation_test.go +++ b/routing/result_interpretation_test.go @@ -64,14 +64,27 @@ var ( SourcePubKey: hops[0], TotalAmount: 100, Hops: []*route.Hop{ - {PubKeyBytes: hops[1], AmtToForward: 99}, { - PubKeyBytes: hops[2], - AmtToForward: 95, - BlindingPoint: blindingPoint, + PubKeyBytes: hops[1], + AmtToForward: 99, + }, + { + PubKeyBytes: hops[2], + // Intermediate blinded hops don't have an + // amount set. + AmtToForward: 0, + BlindingPoint: genTestPubKey(), + }, + { + PubKeyBytes: hops[3], + // Intermediate blinded hops don't have an + // amount set. + AmtToForward: 0, + }, + { + PubKeyBytes: hops[4], + AmtToForward: 77, }, - {PubKeyBytes: hops[3], AmtToForward: 88}, - {PubKeyBytes: hops[4], AmtToForward: 77}, }, } @@ -81,13 +94,21 @@ var ( SourcePubKey: hops[0], TotalAmount: 100, Hops: []*route.Hop{ - {PubKeyBytes: hops[1], AmtToForward: 99}, { - PubKeyBytes: hops[2], - AmtToForward: 95, - BlindingPoint: blindingPoint, + PubKeyBytes: hops[1], + AmtToForward: 99, + }, + { + PubKeyBytes: hops[2], + // Intermediate blinded hops don't have an + // amount set. + AmtToForward: 0, + BlindingPoint: genTestPubKey(), + }, + { + PubKeyBytes: hops[3], + AmtToForward: 88, }, - {PubKeyBytes: hops[3], AmtToForward: 88}, }, } @@ -98,12 +119,22 @@ var ( TotalAmount: 100, Hops: []*route.Hop{ { - PubKeyBytes: hops[1], - AmtToForward: 90, - BlindingPoint: blindingPoint, + PubKeyBytes: hops[1], + // Intermediate blinded hops don't have an + // amount set. + AmtToForward: 0, + BlindingPoint: genTestPubKey(), + }, + { + PubKeyBytes: hops[2], + // Intermediate blinded hops don't have an + // amount set. + AmtToForward: 0, + }, + { + PubKeyBytes: hops[3], + AmtToForward: 58, }, - {PubKeyBytes: hops[2], AmtToForward: 75}, - {PubKeyBytes: hops[3], AmtToForward: 58}, }, } @@ -113,7 +144,10 @@ var ( SourcePubKey: hops[0], TotalAmount: 100, Hops: []*route.Hop{ - {PubKeyBytes: hops[1], AmtToForward: 95}, + { + PubKeyBytes: hops[1], + AmtToForward: 95, + }, { PubKeyBytes: hops[2], AmtToForward: 90, @@ -123,6 +157,12 @@ var ( } ) +func genTestPubKey() *btcec.PublicKey { + key, _ := btcec.NewPrivateKey() + + return key.PubKey() +} + func getTestPair(from, to int) DirectedNodePair { return NewDirectedNodePair(hops[from], hops[to]) } @@ -494,7 +534,12 @@ var resultTestCases = []resultTestCase{ pairResults: map[DirectedNodePair]pairResult{ getTestPair(0, 1): successPairResult(100), getTestPair(1, 2): successPairResult(99), - getTestPair(3, 4): failPairResult(88), + + // The amount for the last hop is always the + // receiver amount because the amount to forward + // is always set to 0 for intermediate blinded + // hops. + getTestPair(3, 4): failPairResult(77), }, }, }, @@ -509,7 +554,12 @@ var resultTestCases = []resultTestCase{ expectedResult: &interpretedResult{ pairResults: map[DirectedNodePair]pairResult{ getTestPair(0, 1): successPairResult(100), - getTestPair(2, 3): failPairResult(75), + + // The amount for the last hop is always the + // receiver amount because the amount to forward + // is always set to 0 for intermediate blinded + // hops. + getTestPair(2, 3): failPairResult(58), }, }, }, @@ -624,6 +674,25 @@ var resultTestCases = []resultTestCase{ finalFailureReason: &reasonError, }, }, + // Test a multi-hop blinded route and that in a success case the amounts + // for the blinded route part are correctly set to the receiver amount. + { + name: "blinded multi-hop success", + route: &blindedMultiToIntroduction, + success: true, + expectedResult: &interpretedResult{ + pairResults: map[DirectedNodePair]pairResult{ + getTestPair(0, 1): successPairResult(100), + + // For the route blinded part of the route the + // success amount is determined by the receiver + // amount because the intermediate blinded hops + // set the forwarded amount to 0. + getTestPair(1, 2): successPairResult(58), + getTestPair(2, 3): successPairResult(58), + }, + }, + }, } // TestResultInterpretation executes a list of test cases that test the result diff --git a/routing/router.go b/routing/router.go index 04096fa67f..aa31933866 100644 --- a/routing/router.go +++ b/routing/router.go @@ -157,7 +157,7 @@ type PaymentSessionSource interface { // finding a path to the payment's destination. NewPaymentSession(p *LightningPayment, firstHopBlob fn.Option[tlv.Blob], - trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, + ts fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, error) // NewPaymentSessionEmpty creates a new paymentSession instance that is @@ -297,7 +297,7 @@ type Config struct { // TrafficShaper is an optional traffic shaper that can be used to // control the outgoing channel of a payment. - TrafficShaper fn.Option[TlvTrafficShaper] + TrafficShaper fn.Option[htlcswitch.AuxTrafficShaper] } // EdgeLocator is a struct used to identify a specific edge. diff --git a/routing/router_test.go b/routing/router_test.go index 53c49f1cfc..eab57236bc 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -164,7 +164,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, Clock: clock.NewTestClock(time.Unix(1, 0)), ApplyChannelUpdate: graphBuilder.ApplyChannelUpdate, ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper]( + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, ), }) @@ -2194,8 +2194,10 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Register mockers with the expected method calls. @@ -2279,8 +2281,10 @@ func TestSendToRouteSkipTempErrNonMPP(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Expect an error to be returned. @@ -2335,8 +2339,10 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Create the error to be returned. @@ -2419,8 +2425,10 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Create the error to be returned. @@ -2507,8 +2515,10 @@ func TestSendToRouteTempFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Create the error to be returned. diff --git a/server.go b/server.go index 3dd8040e0a..8e7d225d5c 100644 --- a/server.go +++ b/server.go @@ -4126,6 +4126,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, MsgRouter: s.implCfg.MsgRouter, AuxChanCloser: s.implCfg.AuxChanCloser, AuxResolver: s.implCfg.AuxContractResolver, + AuxTrafficShaper: s.implCfg.TrafficShaper, } copy(pCfg.PubKeyBytes[:], peerAddr.IdentityKey.SerializeCompressed())