-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathprudp_server.go
341 lines (279 loc) · 11 KB
/
prudp_server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
package nex
import (
"bytes"
"crypto/rand"
"fmt"
"net"
"runtime"
"time"
"github.com/PretendoNetwork/nex-go/v2/constants"
"github.com/lxzan/gws"
)
// PRUDPServer represents a bare-bones PRUDP server.
//
// It owns the transport (a UDP socket and/or a WebSocket server), the set of
// bound PRUDPEndPoints keyed by stream ID, and the settings shared by every
// endpoint bound to it. Use NewPRUDPServer to obtain an instance with the
// endpoint map and settings structs initialized.
type PRUDPServer struct {
// * UDP listener. Assigned by ListenUDP; nil until then. Read by the
// * listenDatagram goroutines and written to by sendRaw
udpSocket *net.UDPConn
// * WebSocket transport. Assigned by ListenWebSocket/ListenWebSocketSecure.
// * When non-nil, incoming data starting with 0x80 is parsed as PRUDPLite
websocketServer *WebSocketServer
// * Endpoints bound through BindPRUDPEndPoint, keyed by their StreamID
Endpoints *MutexMap[uint8, *PRUDPEndPoint]
// * NOTE(review): not referenced in this file; presumably the PRUDP
// * "supported functions" value advertised to clients - confirm
SupportedFunctions uint32
// * NOTE(review): not referenced in this file; presumably the game
// * access key used when signing packets - confirm
AccessKey string
// * NOTE(review): not referenced in this file - confirm semantics with the Kerberos code
KerberosTicketVersion int
// * NewPRUDPServer defaults this to 32
SessionKeyLength int
// * Maximum payload size of a single fragment produced by Send.
// * NewPRUDPServer defaults this to 1300; see SetFragmentSize
FragmentSize int
// * Key for PRUDPv1 connection signatures. If its length is not exactly 16
// * when a Listen* method is called, it is replaced with 16 random bytes
PRUDPv1ConnectionSignatureKey []byte
// * Passed to NewByteStreamIn when decoding incoming data
LibraryVersions *LibraryVersions
// * Passed to NewByteStreamIn when decoding incoming data
ByteStreamSettings *ByteStreamSettings
// * NOTE(review): not referenced in this file; presumably PRUDPv0-specific options - confirm
PRUDPV0Settings *PRUDPV0Settings
// * sendPacket reads LegacyConnectionSignature from here to decide which
// * connection signature is fed into the packet signature
PRUDPV1Settings *PRUDPV1Settings
// * NOTE(review): not referenced in this file; presumably toggles the
// * verbose RMC message format - confirm
UseVerboseRMC bool
}
// BindPRUDPEndPoint binds a provided PRUDPEndPoint to the server.
//
// The endpoint is registered under its StreamID and its Server field is
// pointed back at this server. If an endpoint with the same StreamID is
// already bound, a warning is logged and the call has no other effect.
func (ps *PRUDPServer) BindPRUDPEndPoint(endpoint *PRUDPEndPoint) {
	streamID := endpoint.StreamID

	if !ps.Endpoints.Has(streamID) {
		endpoint.Server = ps
		ps.Endpoints.Set(streamID, endpoint)

		return
	}

	logger.Warningf("Tried to bind already existing PRUDPEndPoint %d", streamID)
}
// Listen is an alias of ListenUDP: it starts the server on the given port
// over UDP and, like ListenUDP, blocks the caller while the server runs.
// Implemented to conform to the EndpointInterface
func (ps *PRUDPServer) Listen(port int) {
ps.ListenUDP(port)
}
// ListenUDP starts a PRUDP server on a given port using a UDP server.
//
// It first makes sure a PRUDPv1 connection signature key is available, then
// opens a UDP socket on the port and starts one listenDatagram goroutine per
// CPU reported by the runtime. The call blocks until one of those goroutines
// signals that it has stopped reading.
//
// Panics if the listen address cannot be resolved or the socket cannot be
// opened.
func (ps *PRUDPServer) ListenUDP(port int) {
	ps.initPRUDPv1ConnectionSignatureKey()

	listenAddress, resolveErr := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
	if resolveErr != nil {
		panic(resolveErr)
	}

	connection, listenErr := net.ListenUDP("udp", listenAddress)
	if listenErr != nil {
		panic(listenErr)
	}

	ps.udpSocket = connection

	// * Every reader shares this channel, the first one to fail unblocks us
	stopped := make(chan struct{})

	for remaining := runtime.NumCPU(); remaining > 0; remaining-- {
		go ps.listenDatagram(stopped)
	}

	<-stopped
}
// listenDatagram is the body of a single UDP reader goroutine.
//
// It reads datagrams from the server's UDP socket in a loop, copies each one
// out of the shared read buffer and hands the copy to handleSocketMessage.
// The loop ends on the first read or handling error; the goroutine then
// signals quit and panics with that error.
func (ps *PRUDPServer) listenDatagram(quit chan struct{}) {
	// * Scratch space reused for every read. Each datagram is copied out
	// * of it before being handed off, since packets are processed on
	// * other goroutines while this buffer gets overwritten
	scratch := make([]byte, 64000)

	var failure error

	for {
		size, remote, readErr := ps.udpSocket.ReadFromUDP(scratch)
		if readErr != nil {
			failure = readErr
			break
		}

		datagram := make([]byte, size)
		copy(datagram, scratch[:size])

		if failure = ps.handleSocketMessage(datagram, remote, nil); failure != nil {
			break
		}
	}

	quit <- struct{}{}

	panic(failure)
}
// ListenWebSocket starts a PRUDP server on a given port using a WebSocket server.
//
// A PRUDPv1 connection signature key is generated first if the server does
// not have a 16 byte one yet. A fresh WebSocketServer bound to this PRUDP
// server then replaces any previously stored one and is told to listen.
func (ps *PRUDPServer) ListenWebSocket(port int) {
	ps.initPRUDPv1ConnectionSignatureKey()

	transport := new(WebSocketServer)
	transport.prudpServer = ps

	ps.websocketServer = transport
	ps.websocketServer.listen(port)
}
// ListenWebSocketSecure starts a PRUDP server on a given port using a secure (TLS) WebSocket server.
//
// A PRUDPv1 connection signature key is generated first if the server does
// not have a 16 byte one yet. A fresh WebSocketServer bound to this PRUDP
// server then replaces any previously stored one and is told to listen with
// the given certificate and key files.
func (ps *PRUDPServer) ListenWebSocketSecure(port int, certFile, keyFile string) {
	ps.initPRUDPv1ConnectionSignatureKey()

	transport := new(WebSocketServer)
	transport.prudpServer = ps

	ps.websocketServer = transport
	ps.websocketServer.listenSecure(port, certFile, keyFile)
}
// initPRUDPv1ConnectionSignatureKey ensures the server has a key for PRUDPv1
// connection signatures.
//
// A key that is already exactly 16 bytes long is kept as-is. Anything else
// (including no key at all) is replaced with 16 bytes read from crypto/rand.
// Panics if the random source fails.
func (ps *PRUDPServer) initPRUDPv1ConnectionSignatureKey() {
	const keyLength = 16

	// * A usable key was already provided, nothing to do
	if len(ps.PRUDPv1ConnectionSignatureKey) == keyLength {
		return
	}

	ps.PRUDPv1ConnectionSignatureKey = make([]byte, keyLength)

	if _, err := rand.Read(ps.PRUDPv1ConnectionSignatureKey); err != nil {
		panic(err)
	}
}
// handleSocketMessage decodes one raw transport message (a UDP datagram or a
// WebSocket message) into PRUDP packets and dispatches every decoded packet
// to processPacket on its own goroutine.
//
// packetData is the raw message, address is the remote peer it came from and
// webSocketConnection is the originating WebSocket connection (nil for UDP).
//
// Empty messages are dropped. The function currently always returns nil; the
// error return exists so listenDatagram can stop on a fatal handling error.
func (ps *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, webSocketConnection *gws.Conn) error {
	// * A zero-length datagram is legal on the wire but carries no packet.
	// * Drop it here: peeking at the leading bytes below would otherwise
	// * panic, and a panic on a reader goroutine takes the whole server
	// * down. That must never be triggerable by a remote peer
	if len(packetData) == 0 {
		return nil
	}

	readStream := NewByteStreamIn(packetData, ps.LibraryVersions, ps.ByteStreamSettings)

	var packets []PRUDPPacketInterface

	// * Support any packet type the client sends and respond
	// * with that same type. Also keep reading from the stream
	// * until no more data is left, to account for multiple
	// * packets being sent at once.
	// *
	// * Decode errors are deliberately ignored. The data comes from
	// * the network, so malformed input is expected and simply results
	// * in whatever packets the decoder returned being processed
	switch {
	case ps.websocketServer != nil && packetData[0] == 0x80:
		packets, _ = NewPRUDPPacketsLite(ps, nil, readStream)
	case bytes.HasPrefix(packetData, []byte{0xEA, 0xD0}):
		// * HasPrefix is safe on messages shorter than the magic,
		// * unlike slicing the first 2 bytes
		packets, _ = NewPRUDPPacketsV1(ps, nil, readStream)
	default:
		packets, _ = NewPRUDPPacketsV0(ps, nil, readStream)
	}

	for _, packet := range packets {
		go ps.processPacket(packet, address, webSocketConnection)
	}

	return nil
}
// processPacket validates the virtual port information of a decoded packet
// and, if it is acceptable, forwards the packet to the PRUDPEndPoint bound to
// the packet's destination stream ID.
//
// address is the remote peer and webSocketConnection the originating
// WebSocket connection (nil for UDP); both are wrapped in a SocketConnection
// handed to the endpoint.
//
// A packet is dropped with a warning when:
//   - no endpoint is bound to its destination stream ID
//   - its source and destination stream types differ
//   - either stream type is greater than constants.StreamTypeRelay
//   - its source port number is out of range for the packet version
func (ps *PRUDPServer) processPacket(packet PRUDPPacketInterface, address net.Addr, webSocketConnection *gws.Conn) {
	destinationStreamID := packet.DestinationVirtualPortStreamID()

	endpoint, ok := ps.Endpoints.Get(destinationStreamID)
	if !ok {
		logger.Warningf("Client %s trying to connect to unbound PRUDPEndPoint %d", address.String(), destinationStreamID)
		return
	}

	destinationStreamType := packet.DestinationVirtualPortStreamType()
	sourceStreamType := packet.SourceVirtualPortStreamType()

	if destinationStreamType != sourceStreamType {
		logger.Warningf("Client %s trying to use non matching destination and source stream types %d and %d", address.String(), destinationStreamType, sourceStreamType)
		return
	}

	if destinationStreamType > constants.StreamTypeRelay {
		logger.Warningf("Client %s trying to use invalid to destination stream type %d", address.String(), destinationStreamType)
		return
	}

	if sourceStreamType > constants.StreamTypeRelay {
		logger.Warningf("Client %s trying to use invalid to source stream type %d", address.String(), sourceStreamType)
		return
	}

	sourcePortNumber := packet.SourceVirtualPortStreamID()

	// * PRUDPLite packets can use port numbers 0-31
	// * PRUDPv0 and PRUDPv1 can use port numbers 0-15
	// *
	// * The two limits must be mutually exclusive. Letting a Lite packet
	// * fall through to the 0-15 check would reject the valid Lite
	// * ports 16-31
	var invalidSourcePort bool
	if _, isLite := packet.(*PRUDPPacketLite); isLite {
		invalidSourcePort = sourcePortNumber > 31
	} else {
		invalidSourcePort = sourcePortNumber > 15
	}

	if invalidSourcePort {
		logger.Warningf("Client %s trying to use invalid to source port number %d. Port number too large", address.String(), sourcePortNumber)
		return
	}

	socket := NewSocketConnection(ps, address, webSocketConnection)
	endpoint.processPacket(packet, socket)
}
// Send sends the packet to the packets sender.
//
// Only packets implementing PRUDPPacketInterface are sent; anything else is
// silently ignored. The payload is split into fragments of at most
// FragmentSize bytes. Full-size fragments are numbered 1, 2, 3, ... and the
// final fragment, which is shorter than FragmentSize and may be empty when
// the payload is an exact multiple of it, carries fragment ID 0.
//
// A non-positive FragmentSize disables fragmentation: the whole payload is
// sent as a single packet with fragment ID 0.
//
// The packet passed in is reused for every fragment, so its payload and
// fragment ID are overwritten by this call. Between fragments the call
// sleeps for 16ms, so it blocks for roughly 16ms per extra fragment.
//
// NOTE(review): the fragment ID is a uint8, so a payload needing more than
// 255 full-size fragments would wrap around to 0 mid-stream.
func (ps *PRUDPServer) Send(packet PacketInterface) {
	prudpPacket, ok := packet.(PRUDPPacketInterface)
	if !ok {
		return
	}

	data := prudpPacket.Payload()

	// * Guard against a zero or negative fragment size, which would
	// * otherwise panic on the division or the slicing below. Picking a
	// * size larger than the payload sends it as one final fragment
	fragmentSize := ps.FragmentSize
	if fragmentSize <= 0 {
		fragmentSize = len(data) + 1
	}

	fragments := len(data) / fragmentSize

	var fragmentID uint8 = 1

	for i := 0; i <= fragments; i++ {
		if len(data) < fragmentSize {
			// * Last fragment
			prudpPacket.SetPayload(data)
			prudpPacket.setFragmentID(0)
		} else {
			prudpPacket.SetPayload(data[:fragmentSize])
			prudpPacket.setFragmentID(fragmentID)

			data = data[fragmentSize:]
			fragmentID++
		}

		ps.sendPacket(prudpPacket)

		// * This delay is here to prevent the server from overloading the client with too many packets.
		// * The 16ms (1/60th of a second) value is chosen based on testing with the friends server and is a good balance between
		// * not being too slow and also not dropping any packets because we've overloaded the client. This may be because it
		// * roughly matches the framerate that most games target (60fps)
		if i < fragments {
			time.Sleep(16 * time.Millisecond)
		}
	}
}
// sendPacket finalizes a single packet (or fragment) and writes it to the
// connection it belongs to.
//
// Working on a copy of the given packet it, in this order: assigns a
// sequence ID, stamps the server session ID, transforms the payload of data
// packets, calculates the signature, records send bookkeeping, schedules a
// resend timeout where needed and finally hands the encoded bytes to sendRaw.
//
// The statement order matters: sequence IDs are taken from per-connection
// counters and the payload transforms go through per-substream stream
// settings, so reordering or repeating steps changes what is sent.
//
// NOTE(review): the packet's Sender is asserted to be a *PRUDPConnection
// without a check, so any other sender type panics here.
func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) {
// * PRUDPServer.Send will send fragments as the same packet,
// * just with different fields. In order to prevent modifying
// * multiple packets at once, due to the same pointer being
// * reused, we must make a copy of the packet being sent
packetCopy := packet.Copy()
connection := packetCopy.Sender().(*PRUDPConnection)
// * Ack and MultiAck packets keep whatever sequence ID they already
// * carry. Everything else gets one assigned here:
// *   - reliable packets take the next ID of their substream's sliding window
// *   - unreliable data packets use the connection's unreliable counter
// *   - ping packets use the connection's ping counter (and record the send time)
// *   - any other packet is sent with sequence ID 0
if !packetCopy.HasFlag(constants.PacketFlagAck) && !packetCopy.HasFlag(constants.PacketFlagMultiAck) {
if packetCopy.HasFlag(constants.PacketFlagReliable) {
slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID())
packetCopy.SetSequenceID(slidingWindow.NextOutgoingSequenceID())
} else if packetCopy.Type() == constants.DataPacket {
packetCopy.SetSequenceID(connection.outgoingUnreliableSequenceIDCounter.Next())
} else if packetCopy.Type() == constants.PingPacket {
packetCopy.SetSequenceID(connection.outgoingPingSequenceIDCounter.Next())
connection.lastSentPingTime = time.Now()
} else {
packetCopy.SetSequenceID(0)
}
}
packetCopy.SetSessionID(connection.ServerSessionID)
// * Only the payload of non-ack data packets is transformed
if packetCopy.Type() == constants.DataPacket && !packetCopy.HasFlag(constants.PacketFlagAck) && !packetCopy.HasFlag(constants.PacketFlagMultiAck) {
if packetCopy.HasFlag(constants.PacketFlagReliable) {
// * Reliable payloads are compressed and then encrypted using the
// * stream settings of the substream's sliding window
slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID())
payload := packetCopy.Payload()
// * NOTE(review): compression/encryption errors are only logged. The
// * packet is still sent with whatever the failing call returned
compressedPayload, err := slidingWindow.streamSettings.CompressionAlgorithm.Compress(payload)
if err != nil {
logger.Error(err.Error())
}
encryptedPayload, err := slidingWindow.streamSettings.EncryptionAlgorithm.Encrypt(compressedPayload)
if err != nil {
logger.Error(err.Error())
}
packetCopy.SetPayload(encryptedPayload)
} else {
// * PRUDPLite does not encrypt payloads, since they go over WSS
// * (the check below treats version 2 as PRUDPLite)
if packetCopy.Version() != 2 {
packetCopy.SetPayload(packetCopy.processUnreliableCrypto())
}
}
}
// * The signature is calculated after the payload has reached its final
// * form. Which connection signature is mixed in depends on the
// * LegacyConnectionSignature setting
if ps.PRUDPV1Settings.LegacyConnectionSignature {
packetCopy.setSignature(packetCopy.calculateSignature(connection.SessionKey, connection.Signature))
} else {
packetCopy.setSignature(packetCopy.calculateSignature(connection.SessionKey, connection.ServerConnectionSignature))
}
// * Send bookkeeping stored on the packet copy
packetCopy.incrementSendCount()
packetCopy.setSentAt(time.Now())
// * Reliable packets which expect an acknowledgement are registered
// * with the substream's timeout manager before going out
if packetCopy.HasFlag(constants.PacketFlagReliable) && packetCopy.HasFlag(constants.PacketFlagNeedsAck) {
slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID())
slidingWindow.TimeoutManager.SchedulePacketTimeout(packetCopy)
}
ps.sendRaw(packetCopy.Sender().(*PRUDPConnection).Socket, packetCopy.Bytes())
}
// sendRaw will send the given socket the provided packet.
//
// The transport is picked from the socket itself: a socket with a UDP
// address is written to through the server's UDP socket (when one is open),
// otherwise the data goes out as a binary message on the socket's WebSocket
// connection. If neither transport applies nothing is sent. Write errors are
// logged and not returned.
func (ps *PRUDPServer) sendRaw(socket *SocketConnection, data []byte) {
	// TODO - Should this return the error too?
	var writeErr error

	udpAddress, isUDP := socket.Address.(*net.UDPAddr)

	switch {
	case isUDP && ps.udpSocket != nil:
		_, writeErr = ps.udpSocket.WriteToUDP(data, udpAddress)
	case socket.WebSocketConnection != nil:
		writeErr = socket.WebSocketConnection.WriteMessage(gws.OpcodeBinary, data)
	}

	if writeErr == nil {
		return
	}

	logger.Error(writeErr.Error())
}
// SetFragmentSize sets the max size for a packets payload.
//
// The value is stored as-is in FragmentSize, without validation, and is used
// by Send to decide where a payload is split into fragments.
func (ps *PRUDPServer) SetFragmentSize(fragmentSize int) {
// TODO - Derive this value from the MTU
// * From the wiki:
// *
// * The fragment size depends on the implementation.
// * It is generally set to the MTU minus the packet overhead.
// *
// * In old NEX versions, which only support PRUDP v0, the MTU is
// * hardcoded to 1000 and the maximum payload size seems to be 962 bytes.
// *
// * Later, the MTU was increased to 1364, and the maximum payload
// * size seems to be 1300 bytes, unless PRUDP v0 is used, in which case it's 1264 bytes.
ps.FragmentSize = fragmentSize
}
// NewPRUDPServer will return a new PRUDP server.
//
// The returned server has an empty endpoint map, a session key length of 32,
// a fragment size of 1300 and freshly constructed library version, byte
// stream, PRUDPv0 and PRUDPv1 settings. Every other field is left at its
// zero value. No socket is opened until one of the Listen* methods is called.
func NewPRUDPServer() *PRUDPServer {
	server := new(PRUDPServer)

	server.Endpoints = NewMutexMap[uint8, *PRUDPEndPoint]()
	server.SessionKeyLength = 32
	server.FragmentSize = 1300
	server.LibraryVersions = NewLibraryVersions()
	server.ByteStreamSettings = NewByteStreamSettings()
	server.PRUDPV0Settings = NewPRUDPV0Settings()
	server.PRUDPV1Settings = NewPRUDPV1Settings()

	return server
}