Skip to content

Commit

Permalink
Ensure p2p protocol matches new Starknet spec
Browse files Browse the repository at this point in the history
  • Loading branch information
wojciechos committed Jan 15, 2025
1 parent 06717e2 commit 7a7d296
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 23 deletions.
16 changes: 8 additions & 8 deletions p2p/p2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
}
}

p2pdht, err := makeDHT(p2phost, peersAddrInfoS)
p2pdht, err := MakeDHT(p2phost, peersAddrInfoS, snNetwork)
if err != nil {
return nil, err
}
Expand All @@ -160,9 +160,9 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
return s, nil
}

func makeDHT(p2phost host.Host, addrInfos []peer.AddrInfo) (*dht.IpfsDHT, error) {
func MakeDHT(p2phost host.Host, addrInfos []peer.AddrInfo, network *utils.Network) (*dht.IpfsDHT, error) {
return dht.New(context.Background(), p2phost,
dht.ProtocolPrefix(p2pSync.Prefix),
dht.ProtocolPrefix(p2pSync.DHTPrefixPID(network)),
dht.BootstrapPeers(addrInfos...),
dht.RoutingTableRefreshPeriod(routingTableRefreshPeriod),
dht.Mode(dht.ModeServer),
Expand Down Expand Up @@ -250,11 +250,11 @@ func (s *Service) Run(ctx context.Context) error {
}

func (s *Service) setProtocolHandlers() {
s.SetProtocolHandler(p2pSync.HeadersPID(), s.handler.HeadersHandler)
s.SetProtocolHandler(p2pSync.EventsPID(), s.handler.EventsHandler)
s.SetProtocolHandler(p2pSync.TransactionsPID(), s.handler.TransactionsHandler)
s.SetProtocolHandler(p2pSync.ClassesPID(), s.handler.ClassesHandler)
s.SetProtocolHandler(p2pSync.StateDiffPID(), s.handler.StateDiffHandler)
s.SetProtocolHandler(p2pSync.HeadersPID(s.network), s.handler.HeadersHandler)
s.SetProtocolHandler(p2pSync.EventsPID(s.network), s.handler.EventsHandler)
s.SetProtocolHandler(p2pSync.TransactionsPID(s.network), s.handler.TransactionsHandler)
s.SetProtocolHandler(p2pSync.ClassesPID(s.network), s.handler.ClassesHandler)
s.SetProtocolHandler(p2pSync.StateDiffPID(s.network), s.handler.StateDiffHandler)
}

func (s *Service) callAndLogErr(f func() error, msg string) {
Expand Down
36 changes: 36 additions & 0 deletions p2p/p2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import (
"github.com/NethermindEth/juno/p2p"
"github.com/NethermindEth/juno/utils"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
"github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -64,3 +67,36 @@ func TestLoadAndPersistPeers(t *testing.T) {
)
require.NoError(t, err)
}

func TestMakeDHTProtocolName(t *testing.T) {
net, err := mocknet.FullMeshLinked(1)
require.NoError(t, err)
testHost := net.Hosts()[0]

testCases := []struct {
name string
network *utils.Network
expected string
}{
{
name: "sepolia network",
network: &utils.Sepolia,
expected: "/starknet/SN_SEPOLIA/sync/kad/1.0.0",
},
{
name: "mainnet network",
network: &utils.Mainnet,
expected: "/starknet/SN_MAIN/sync/kad/1.0.0",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
dht, err := p2p.MakeDHT(testHost, nil, tc.network)
require.NoError(t, err)

protocols := dht.Host().Mux().Protocols()
assert.Contains(t, protocols, protocol.ID(tc.expected), "protocol list: %v", protocols)
})
}
}
10 changes: 5 additions & 5 deletions p2p/sync/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,22 @@ func (c *Client) RequestBlockHeaders(
ctx context.Context, req *gen.BlockHeadersRequest,
) (iter.Seq[*gen.BlockHeadersResponse], error) {
return requestAndReceiveStream[*gen.BlockHeadersRequest, *gen.BlockHeadersResponse](
ctx, c.newStream, HeadersPID(), req, c.log)
ctx, c.newStream, HeadersPID(c.network), req, c.log)
}

func (c *Client) RequestEvents(ctx context.Context, req *gen.EventsRequest) (iter.Seq[*gen.EventsResponse], error) {
return requestAndReceiveStream[*gen.EventsRequest, *gen.EventsResponse](ctx, c.newStream, EventsPID(), req, c.log)
return requestAndReceiveStream[*gen.EventsRequest, *gen.EventsResponse](ctx, c.newStream, EventsPID(c.network), req, c.log)
}

func (c *Client) RequestClasses(ctx context.Context, req *gen.ClassesRequest) (iter.Seq[*gen.ClassesResponse], error) {
return requestAndReceiveStream[*gen.ClassesRequest, *gen.ClassesResponse](ctx, c.newStream, ClassesPID(), req, c.log)
return requestAndReceiveStream[*gen.ClassesRequest, *gen.ClassesResponse](ctx, c.newStream, ClassesPID(c.network), req, c.log)
}

func (c *Client) RequestStateDiffs(ctx context.Context, req *gen.StateDiffsRequest) (iter.Seq[*gen.StateDiffsResponse], error) {
return requestAndReceiveStream[*gen.StateDiffsRequest, *gen.StateDiffsResponse](ctx, c.newStream, StateDiffPID(), req, c.log)
return requestAndReceiveStream[*gen.StateDiffsRequest, *gen.StateDiffsResponse](ctx, c.newStream, StateDiffPID(c.network), req, c.log)
}

func (c *Client) RequestTransactions(ctx context.Context, req *gen.TransactionsRequest) (iter.Seq[*gen.TransactionsResponse], error) {
return requestAndReceiveStream[*gen.TransactionsRequest, *gen.TransactionsResponse](
ctx, c.newStream, TransactionsPID(), req, c.log)
ctx, c.newStream, TransactionsPID(c.network), req, c.log)
}
25 changes: 15 additions & 10 deletions p2p/sync/ids.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,32 @@
package sync

import (
"github.com/NethermindEth/juno/utils"
"github.com/libp2p/go-libp2p/core/protocol"
)

const Prefix = "/starknet"

func HeadersPID() protocol.ID {
return Prefix + "/headers/0.1.0-rc.0"
func HeadersPID(network *utils.Network) protocol.ID {
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/headers/0.1.0-rc.0")
}

func EventsPID() protocol.ID {
return Prefix + "/events/0.1.0-rc.0"
func EventsPID(network *utils.Network) protocol.ID {
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/events/0.1.0-rc.0")
}

func TransactionsPID() protocol.ID {
return Prefix + "/transactions/0.1.0-rc.0"
func TransactionsPID(network *utils.Network) protocol.ID {
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/transactions/0.1.0-rc.0")
}

func ClassesPID() protocol.ID {
return Prefix + "/classes/0.1.0-rc.0"
func ClassesPID(network *utils.Network) protocol.ID {
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/classes/0.1.0-rc.0")
}

func StateDiffPID() protocol.ID {
return Prefix + "/state_diffs/0.1.0-rc.0"
func StateDiffPID(network *utils.Network) protocol.ID {
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/state_diffs/0.1.0-rc.0")
}

func DHTPrefixPID(network *utils.Network) protocol.ID {
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync")
}
67 changes: 67 additions & 0 deletions p2p/sync/ids_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package sync

import (
"testing"

"github.com/NethermindEth/juno/utils"
"github.com/stretchr/testify/assert"
)

func TestProtocolIDs(t *testing.T) {
testCases := []struct {
name string
network *utils.Network
pidFunc func(*utils.Network) string
expected string
}{
{
name: "HeadersPID with SN_MAIN",
network: &utils.Mainnet,
pidFunc: func(n *utils.Network) string { return string(HeadersPID(n)) },
expected: "/starknet/SN_MAIN/sync/headers/0.1.0-rc.0",
},
{
name: "EventsPID with SN_MAIN",
network: &utils.Mainnet,
pidFunc: func(n *utils.Network) string { return string(EventsPID(n)) },
expected: "/starknet/SN_MAIN/sync/events/0.1.0-rc.0",
},
{
name: "TransactionsPID with SN_MAIN",
network: &utils.Mainnet,
pidFunc: func(n *utils.Network) string { return string(TransactionsPID(n)) },
expected: "/starknet/SN_MAIN/sync/transactions/0.1.0-rc.0",
},
{
name: "ClassesPID with SN_MAIN",
network: &utils.Mainnet,
pidFunc: func(n *utils.Network) string { return string(ClassesPID(n)) },
expected: "/starknet/SN_MAIN/sync/classes/0.1.0-rc.0",
},
{
name: "StateDiffPID with SN_MAIN",
network: &utils.Mainnet,
pidFunc: func(n *utils.Network) string { return string(StateDiffPID(n)) },
expected: "/starknet/SN_MAIN/sync/state_diffs/0.1.0-rc.0",
},
{
name: "DHTPrefixPID with SN_MAIN",
network: &utils.Mainnet,
pidFunc: func(n *utils.Network) string { return string(DHTPrefixPID(n)) },
expected: "/starknet/SN_MAIN/sync",
},
{
name: "HeadersPID with SN_SEPOLIA",
network: &utils.Sepolia,
pidFunc: func(n *utils.Network) string { return string(HeadersPID(n)) },
expected: "/starknet/SN_SEPOLIA/sync/headers/0.1.0-rc.0",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := tc.pidFunc(tc.network)
assert.Equal(t, tc.expected, result)
})
}
}

0 comments on commit 7a7d296

Please sign in to comment.