From 7a7d2962800db8ec7d20dce13cbd196d3501ef50 Mon Sep 17 00:00:00 2001 From: wojo Date: Wed, 15 Jan 2025 17:49:44 +0700 Subject: [PATCH] Ensure p2p protocol matches new Starknet spec --- p2p/p2p.go | 16 +++++------ p2p/p2p_test.go | 36 ++++++++++++++++++++++++ p2p/sync/client.go | 10 +++---- p2p/sync/ids.go | 25 ++++++++++------- p2p/sync/ids_test.go | 67 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 131 insertions(+), 23 deletions(-) create mode 100644 p2p/sync/ids_test.go diff --git a/p2p/p2p.go b/p2p/p2p.go index ddb1ed955d..bd54ad4c16 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -139,7 +139,7 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai } } - p2pdht, err := makeDHT(p2phost, peersAddrInfoS) + p2pdht, err := MakeDHT(p2phost, peersAddrInfoS, snNetwork) if err != nil { return nil, err } @@ -160,9 +160,9 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai return s, nil } -func makeDHT(p2phost host.Host, addrInfos []peer.AddrInfo) (*dht.IpfsDHT, error) { +func MakeDHT(p2phost host.Host, addrInfos []peer.AddrInfo, network *utils.Network) (*dht.IpfsDHT, error) { return dht.New(context.Background(), p2phost, - dht.ProtocolPrefix(p2pSync.Prefix), + dht.ProtocolPrefix(p2pSync.DHTPrefixPID(network)), dht.BootstrapPeers(addrInfos...), dht.RoutingTableRefreshPeriod(routingTableRefreshPeriod), dht.Mode(dht.ModeServer), @@ -250,11 +250,11 @@ func (s *Service) Run(ctx context.Context) error { } func (s *Service) setProtocolHandlers() { - s.SetProtocolHandler(p2pSync.HeadersPID(), s.handler.HeadersHandler) - s.SetProtocolHandler(p2pSync.EventsPID(), s.handler.EventsHandler) - s.SetProtocolHandler(p2pSync.TransactionsPID(), s.handler.TransactionsHandler) - s.SetProtocolHandler(p2pSync.ClassesPID(), s.handler.ClassesHandler) - s.SetProtocolHandler(p2pSync.StateDiffPID(), s.handler.StateDiffHandler) + s.SetProtocolHandler(p2pSync.HeadersPID(s.network), s.handler.HeadersHandler) + s.SetProtocolHandler(p2pSync.EventsPID(s.network), s.handler.EventsHandler) + s.SetProtocolHandler(p2pSync.TransactionsPID(s.network), s.handler.TransactionsHandler) + s.SetProtocolHandler(p2pSync.ClassesPID(s.network), s.handler.ClassesHandler) + s.SetProtocolHandler(p2pSync.StateDiffPID(s.network), s.handler.StateDiffHandler) } func (s *Service) callAndLogErr(f func() error, msg string) { diff --git a/p2p/p2p_test.go b/p2p/p2p_test.go index 54b19d5900..0e10a4c04c 100644 --- a/p2p/p2p_test.go +++ b/p2p/p2p_test.go @@ -8,7 +8,10 @@ import ( "github.com/NethermindEth/juno/p2p" "github.com/NethermindEth/juno/utils" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -64,3 +67,36 @@ func TestLoadAndPersistPeers(t *testing.T) { ) require.NoError(t, err) } + +func TestMakeDHTProtocolName(t *testing.T) { + net, err := mocknet.FullMeshLinked(1) + require.NoError(t, err) + testHost := net.Hosts()[0] + + testCases := []struct { + name string + network *utils.Network + expected string + }{ + { + name: "sepolia network", + network: &utils.Sepolia, + expected: "/starknet/SN_SEPOLIA/sync/kad/1.0.0", + }, + { + name: "mainnet network", + network: &utils.Mainnet, + expected: "/starknet/SN_MAIN/sync/kad/1.0.0", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dht, err := p2p.MakeDHT(testHost, nil, tc.network) + require.NoError(t, err) + + protocols := dht.Host().Mux().Protocols() + assert.Contains(t, protocols, protocol.ID(tc.expected), "protocol list: %v", protocols) + }) + } +} diff --git a/p2p/sync/client.go b/p2p/sync/client.go index 5f688a0378..5d4271df82 100644 --- a/p2p/sync/client.go +++ b/p2p/sync/client.go @@ -104,22 +104,22 @@ func (c *Client) RequestBlockHeaders( ctx context.Context, req *gen.BlockHeadersRequest, ) (iter.Seq[*gen.BlockHeadersResponse], error) { return requestAndReceiveStream[*gen.BlockHeadersRequest, *gen.BlockHeadersResponse]( - ctx, c.newStream, HeadersPID(), req, c.log) + ctx, c.newStream, HeadersPID(c.network), req, c.log) } func (c *Client) RequestEvents(ctx context.Context, req *gen.EventsRequest) (iter.Seq[*gen.EventsResponse], error) { - return requestAndReceiveStream[*gen.EventsRequest, *gen.EventsResponse](ctx, c.newStream, EventsPID(), req, c.log) + return requestAndReceiveStream[*gen.EventsRequest, *gen.EventsResponse](ctx, c.newStream, EventsPID(c.network), req, c.log) } func (c *Client) RequestClasses(ctx context.Context, req *gen.ClassesRequest) (iter.Seq[*gen.ClassesResponse], error) { - return requestAndReceiveStream[*gen.ClassesRequest, *gen.ClassesResponse](ctx, c.newStream, ClassesPID(), req, c.log) + return requestAndReceiveStream[*gen.ClassesRequest, *gen.ClassesResponse](ctx, c.newStream, ClassesPID(c.network), req, c.log) } func (c *Client) RequestStateDiffs(ctx context.Context, req *gen.StateDiffsRequest) (iter.Seq[*gen.StateDiffsResponse], error) { - return requestAndReceiveStream[*gen.StateDiffsRequest, *gen.StateDiffsResponse](ctx, c.newStream, StateDiffPID(), req, c.log) + return requestAndReceiveStream[*gen.StateDiffsRequest, *gen.StateDiffsResponse](ctx, c.newStream, StateDiffPID(c.network), req, c.log) } func (c *Client) RequestTransactions(ctx context.Context, req *gen.TransactionsRequest) (iter.Seq[*gen.TransactionsResponse], error) { return requestAndReceiveStream[*gen.TransactionsRequest, *gen.TransactionsResponse]( - ctx, c.newStream, TransactionsPID(), req, c.log) + ctx, c.newStream, TransactionsPID(c.network), req, c.log) } diff --git a/p2p/sync/ids.go b/p2p/sync/ids.go index 284875e36b..882247073b 100644 --- a/p2p/sync/ids.go +++ b/p2p/sync/ids.go @@ -1,27 +1,32 @@ package sync import ( + "github.com/NethermindEth/juno/utils" "github.com/libp2p/go-libp2p/core/protocol" ) const Prefix = "/starknet" -func HeadersPID() protocol.ID { - return Prefix + "/headers/0.1.0-rc.0" +func HeadersPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/headers/0.1.0-rc.0") } -func EventsPID() protocol.ID { - return Prefix + "/events/0.1.0-rc.0" +func EventsPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/events/0.1.0-rc.0") } -func TransactionsPID() protocol.ID { - return Prefix + "/transactions/0.1.0-rc.0" +func TransactionsPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/transactions/0.1.0-rc.0") } -func ClassesPID() protocol.ID { - return Prefix + "/classes/0.1.0-rc.0" +func ClassesPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/classes/0.1.0-rc.0") } -func StateDiffPID() protocol.ID { - return Prefix + "/state_diffs/0.1.0-rc.0" +func StateDiffPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/state_diffs/0.1.0-rc.0") } + +func DHTPrefixPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync") +} \ No newline at end of file diff --git a/p2p/sync/ids_test.go b/p2p/sync/ids_test.go new file mode 100644 index 0000000000..8ad6ae6d3e --- /dev/null +++ b/p2p/sync/ids_test.go @@ -0,0 +1,67 @@ +package sync + +import ( + "testing" + + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" +) + +func TestProtocolIDs(t *testing.T) { + testCases := []struct { + name string + network *utils.Network + pidFunc func(*utils.Network) string + expected string + }{ + { + name: "HeadersPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(HeadersPID(n)) }, + expected: "/starknet/SN_MAIN/sync/headers/0.1.0-rc.0", + }, + { + name: "EventsPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(EventsPID(n)) }, + expected: "/starknet/SN_MAIN/sync/events/0.1.0-rc.0", + }, + { + name: "TransactionsPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(TransactionsPID(n)) }, + expected: "/starknet/SN_MAIN/sync/transactions/0.1.0-rc.0", + }, + { + name: "ClassesPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(ClassesPID(n)) }, + expected: "/starknet/SN_MAIN/sync/classes/0.1.0-rc.0", + }, + { + name: "StateDiffPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(StateDiffPID(n)) }, + expected: "/starknet/SN_MAIN/sync/state_diffs/0.1.0-rc.0", + }, + { + name: "DHTPrefixPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(DHTPrefixPID(n)) }, + expected: "/starknet/SN_MAIN/sync", + }, + { + name: "HeadersPID with SN_SEPOLIA", + network: &utils.Sepolia, + pidFunc: func(n *utils.Network) string { return string(HeadersPID(n)) }, + expected: "/starknet/SN_SEPOLIA/sync/headers/0.1.0-rc.0", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.pidFunc(tc.network) + assert.Equal(t, tc.expected, result) + }) + } +}