diff --git a/client.go b/client.go index e30db2dd..bd6e9982 100644 --- a/client.go +++ b/client.go @@ -1,40 +1,80 @@ package turn import ( + b64 "encoding/base64" "fmt" + "math" "net" - "sync" + "time" + "github.com/gortc/turn" "github.com/pion/logging" "github.com/pion/stun" "github.com/pion/transport/vnet" + "github.com/pion/turn/internal/client" "github.com/pkg/errors" ) +const ( + defaultRTO = 200 * time.Millisecond + maxRtxCount = 7 // total 7 requests (Rc) + maxDataBufferSize = math.MaxUint16 //message size limit for Chromium +) + +// interval [msec] +// 0: 0 ms +500 +// 1: 500 ms +1000 +// 2: 1500 ms +2000 +// 3: 3500 ms +4000 +// 4: 7500 ms +8000 +// 5: 15500 ms +16000 +// 6: 31500 ms +32000 +// -: 63500 ms failed + // ClientConfig is a bag of config parameters for Client. type ClientConfig struct { - ListeningAddress string - LoggerFactory logging.LoggerFactory - Net *vnet.Net - Software *stun.Software - Sender Sender + STUNServerAddr string // STUN server address (e.g. "stun.abc.com:3478") + TURNServerAddr string // TURN server addrees (e.g. "turn.abc.com:3478") + Username string + Password string + Realm string + Software *stun.Software + RTO time.Duration + Conn net.PacketConn // Listening socket (net.PacketConn) + LoggerFactory logging.LoggerFactory + Net *vnet.Net } // Client is a STUN server client type Client struct { - conn net.PacketConn - mux sync.Mutex - net *vnet.Net - log logging.LeveledLogger - software *stun.Software - sender Sender + conn net.PacketConn // read-only + stunServ net.Addr // read-only + turnServ net.Addr // read-only + stunServStr string // read-only, used for dmuxing + turnServStr string // read-only, used for dmuxing + username stun.Username // read-only + password string // read-only + realm stun.Realm // read-only + integrity stun.MessageIntegrity // read-only + software *stun.Software + trMap *client.TransactionMap // thread-safe + rto time.Duration // read-only + relayedConn *client.UDPConn // protected by mutex *** + allocTryLock client.TryLock // thread-safe + listenTryLock client.TryLock // thread-safe + net *vnet.Net // read-only + mutex sync.RWMutex // thread-safe + log logging.LeveledLogger // read-only } // NewClient returns a new Client instance. listeningAddress is the address and port to listen on, default "0.0.0.0:0" func NewClient(config *ClientConfig) (*Client, error) { log := config.LoggerFactory.NewLogger("turnc") - network := "udp4" + + if config.Conn == nil { + return nil, fmt.Errorf("conn cannot not be nil") + } if config.Net == nil { config.Net = vnet.NewNet(nil) // defaults to native operation @@ -42,56 +82,127 @@ func NewClient(config *ClientConfig) (*Client, error) { log.Warn("vnet is enabled") } - if config.Sender == nil { - config.Sender = defaultBuildAndSend + var stunServ, turnServ net.Addr + var stunServStr, turnServStr string + var err error + if len(config.STUNServerAddr) > 0 { + log.Debugf("resolving %s", config.STUNServerAddr) + stunServ, err = config.Net.ResolveUDPAddr("udp4", config.STUNServerAddr) + if err != nil { + return nil, err + } + stunServStr = stunServ.String() + log.Debugf("stunServ: %s", stunServStr) + } + if len(config.TURNServerAddr) > 0 { + log.Debugf("resolving %s", config.TURNServerAddr) + turnServ, err = config.Net.ResolveUDPAddr("udp4", config.TURNServerAddr) + if err != nil { + return nil, err + } + turnServStr = turnServ.String() + log.Debugf("turnServ: %s", turnServStr) } c := &Client{ - net: config.Net, - log: log, - software: config.Software, - sender: config.Sender, + conn: config.Conn, + stunServ: stunServ, + turnServ: turnServ, + stunServStr: stunServStr, + turnServStr: turnServStr, + username: stun.NewUsername(config.Username), + password: config.Password, + realm: stun.NewRealm(config.Realm), + software: config.Software, + net: config.Net, + trMap: client.NewTransactionMap(), + rto: defaultRTO, + log: log, } - var err error - c.conn, err = c.net.ListenPacket(network, config.ListeningAddress) - if err != nil { - return nil, errors.Wrap(err, fmt.Sprintf("failed to listen on %s", config.ListeningAddress)) + return c, nil +} + +// TURNServerAddr return the TURN server address +func (c *Client) TURNServerAddr() net.Addr { + return c.turnServ +} + +// STUNServerAddr return the STUN server address +func (c *Client) STUNServerAddr() net.Addr { + return c.stunServ +} + +// Username returns username +func (c *Client) Username() stun.Username { + return c.username +} + +// Realm return realm +func (c *Client) Realm() stun.Realm { + return c.realm +} + +// WriteTo sends data to the specified destination using the base socket. +func (c *Client) WriteTo(data []byte, to net.Addr) (int, error) { + return c.conn.WriteTo(data, to) +} + +// Listen will have this client start listening on the conn provided via the config. +// This is optional. If not used, you will need to call HandleInbound method +// to supply incoming data, instead. +func (c *Client) Listen() error { + if err := c.listenTryLock.Lock(); err != nil { + return fmt.Errorf("already listening: %s", err.Error()) } - return c, nil + go func() { + buf := make([]byte, maxDataBufferSize) + for { + n, from, err := c.conn.ReadFrom(buf) + if err != nil { + c.log.Debugf("exiting read loop: %s", err.Error()) + break + } + + _, err = c.HandleInbound(buf[:n], from) + if err != nil { + c.log.Debugf("exiting read loop: %s", err.Error()) + break + } + } + + c.listenTryLock.Unlock() + }() + + return nil } -// SendSTUNRequest sends a new STUN request to the serverIP with serverPort -func (c *Client) SendSTUNRequest(serverIP net.IP, serverPort int) (net.Addr, error) { - c.mux.Lock() - defer c.mux.Unlock() +// Close closes this client +func (c *Client) Close() { + c.trMap.CloseAndDeleteAll() +} - c.log.Debug("sending STUN request") +// TransactionID & Base64: https://play.golang.org/p/EEgmJDI971P +// SendBindingRequestTo sends a new STUN request to the given transport address +func (c *Client) SendBindingRequestTo(to net.Addr) (net.Addr, error) { attrs := []stun.Setter{stun.TransactionID, stun.BindingRequest} if c.software != nil { attrs = append(attrs, *c.software) } - if err := c.sender(c.conn, &net.UDPAddr{IP: serverIP, Port: serverPort}, attrs...); err != nil { - return nil, err - } - packet := make([]byte, 1500) - c.log.Debug("wait for STUN response...") - size, _, err := c.conn.ReadFrom(packet) + msg, err := stun.Build(attrs...) if err != nil { - return nil, errors.Wrap(err, "failed to read packet from udp socket") + return nil, err } - - c.log.Debugf("received %d bytes of STUN response", size) - resp := &stun.Message{Raw: append([]byte{}, packet[:size]...)} - if err := resp.Decode(); err != nil { - return nil, errors.Wrap(err, "failed to handle reply") + trRes, err := c.PerformTransaction(msg, to, false) + if err != nil { + return nil, err } var reflAddr stun.XORMappedAddress - if err := reflAddr.GetFrom(resp); err != nil { + if err := reflAddr.GetFrom(trRes.Msg); err != nil { return nil, err } @@ -101,3 +212,332 @@ func (c *Client) SendSTUNRequest(serverIP net.IP, serverPort int) (net.Addr, err Port: reflAddr.Port, }, nil } + +// SendBindingRequest sends a new STUN request to the STUN server +func (c *Client) SendBindingRequest() (net.Addr, error) { + if c.stunServ == nil { + return nil, fmt.Errorf("STUN server address is not set for the client") + } + return c.SendBindingRequestTo(c.stunServ) +} + +// Allocate sends a TURN allocation request to the given transport address +func (c *Client) Allocate() (net.PacketConn, error) { + if err := c.allocTryLock.Lock(); err != nil { + return nil, fmt.Errorf("only one Allocate() caller is allowed: %s", err.Error()) + } + defer c.allocTryLock.Unlock() + + relayedConn := c.relayedUDPConn() + if relayedConn != nil { + return nil, fmt.Errorf("already allocated at %s", relayedConn.LocalAddr().String()) + } + + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + turn.RequestedTransportUDP, + stun.Fingerprint, + ) + if err != nil { + return nil, err + } + + trRes, err := c.PerformTransaction(msg, c.turnServ, false) + if err != nil { + return nil, err + } + + res := trRes.Msg + + // Anonymous allocate failed, trying to authenticate. + var nonce stun.Nonce + if err = nonce.GetFrom(res); err != nil { + return nil, err + } + if err = c.realm.GetFrom(res); err != nil { + return nil, err + } + c.realm = append([]byte(nil), c.realm...) + c.integrity = stun.NewLongTermIntegrity( + c.username.String(), c.realm.String(), c.password, + ) + // Trying to authorize. + msg, err = stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + turn.RequestedTransportUDP, + &c.username, + &c.realm, + &nonce, + &c.integrity, + stun.Fingerprint, + ) + if err != nil { + return nil, err + } + + trRes, err = c.PerformTransaction(msg, c.turnServ, false) + if err != nil { + return nil, err + } + res = trRes.Msg + + if res.Type.Class == stun.ClassErrorResponse { + var code stun.ErrorCodeAttribute + if err = code.GetFrom(res); err == nil { + return nil, fmt.Errorf("%s (error %s)", res.Type, code) + } + return nil, fmt.Errorf("%s", res.Type) + } + + // Getting relayed addresses from response. + var relayed turn.RelayedAddress + if err := relayed.GetFrom(res); err != nil { + return nil, err + } + relayedAddr := &net.UDPAddr{ + IP: relayed.IP, + Port: relayed.Port, + } + + // Getting lifetime from response + var lifetime turn.Lifetime + if err := lifetime.GetFrom(res); err != nil { + return nil, err + } + + relayedConn = client.NewUDPConn(&client.UDPConnConfig{ + Observer: c, + RelayedAddr: relayedAddr, + Integrity: c.integrity, + Nonce: nonce, + Lifetime: lifetime.Duration, + Log: c.log, + }) + + c.setRelayedUDPConn(relayedConn) + + return relayedConn, nil +} + +// PerformTransaction performs STUN transaction +func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, dontWait bool) (client.TransactionResult, error) { + trKey := b64.StdEncoding.EncodeToString(msg.TransactionID[:]) + tr := client.NewTransaction(&client.TransactionConfig{ + Key: trKey, + Raw: msg.Raw, + To: to, + Interval: c.rto, + }) + + c.trMap.Insert(trKey, tr) + + c.log.Tracef("start %s transaction %s to %s", msg.Type, trKey, tr.To.String()) + _, err := c.conn.WriteTo(msg.Raw, to) + if err != nil { + return client.TransactionResult{}, err + } + + tr.StartRtxTimer(c.onRtxTimeout) + + // If dontWait is true, get the transaction going and return immediately + if dontWait { + return client.TransactionResult{}, nil + } + + res := tr.WaitForResult() + if res.Err != nil { + return res, res.Err + } + return res, nil +} + +// OnDeallocated is called when deallocation of relay address has been complete. +// (Called by UDPConn) +func (c *Client) OnDeallocated(relayedAddr net.Addr) { + c.setRelayedUDPConn(nil) +} + +// HandleInbound handles data received. +// This method handles incoming packet demultiplex it by the source address +// and the types of the message. +// This return a booleen (handled or not) and if there was an error. +// Caller should check if the packet was handled by this client or not. +// If not handled, it is assumed that the packet is application data. +// If an error is returned, the caller should discard the packet regardless. +func (c *Client) HandleInbound(data []byte, from net.Addr) (bool, error) { + var handled bool + var err error + + switch { + case stun.IsMessage(data): + handled = true + err = c.handleSTUNMessage(data, from) + case len(c.turnServStr) != 0 && from.String() == c.turnServStr: + handled = true + // received from TURN server + if turn.IsChannelData(data) { + err = c.handleChannelData(data) + } else { + err = fmt.Errorf("unexpected packet from TURN server") + } + case len(c.stunServStr) != 0 && from.String() == c.stunServStr: + handled = true + // received from STUN server but it is not a STUN message + err = fmt.Errorf("non-STUN message from STUN server") + default: + // assume, this is an application data + c.log.Tracef("non-STUN/TURN packect, unhandled") + } + + // +---------+---------+-------------------------------+ + // | handled | err | Meaning / Action | + // |=========+=========+===============================+ + // | false | nil | Handle the packet as app data | + // |---------+---------+-------------------------------+ + // | true | nil | Nothing to do | + // |---------+---------+-------------------------------+ + // | false | error | (shouldn't happen) | + // |---------+---------+-------------------------------+ + // | true | error | Error occurred while handling | + // +---------+---------+-------------------------------+ + // Possible causes of the error: + // - Malformed packet (parse error) + // - STUN message was a request + // - Non-STUN message from the STUN server + + return handled, err +} + +func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { + msg := &stun.Message{Raw: data} + if err := msg.Decode(); err != nil { + return errors.Wrap(err, "failed to decode STUN message") + } + + if msg.Type.Class == stun.ClassRequest { + return fmt.Errorf("unpexpected STUN request message: %s", msg.String()) + } + + if msg.Type.Class == stun.ClassIndication { + if msg.Type.Method == stun.MethodData { + var peerAddr turn.PeerAddress + if err := peerAddr.GetFrom(msg); err != nil { + return err + } + from = &net.UDPAddr{ + IP: peerAddr.IP, + Port: peerAddr.Port, + } + + var data turn.Data + if err := data.GetFrom(msg); err != nil { + return err + } + + c.log.Debugf("data indication received from %s", from.String()) + + relayedConn := c.relayedUDPConn() + if relayedConn == nil { + c.log.Debug("no relayed conn allocated") + return nil // silently discard + } + relayedConn.HandleInbound(data, from) + } + return nil + } + + // This is a STUN response message (transactional) + // The type is either: + // - stun.ClassSuccessResponse + // - stun.ClassErrorResponse + + trKey := b64.StdEncoding.EncodeToString(msg.TransactionID[:]) + tr, ok := c.trMap.Find(trKey) + if !ok { + // silently discard + c.log.Debugf("no transaction for %s", msg.String()) + return nil + } + + // End the transaction + tr.StopRtxTimer() + c.trMap.Delete(trKey) + + if !tr.WriteResult(client.TransactionResult{Msg: msg, From: from}) { + c.log.Debugf("no listener for %s", msg.String()) + } + + return nil +} + +func (c *Client) handleChannelData(data []byte) error { + chData := &turn.ChannelData{ + Raw: make([]byte, len(data)), + } + copy(chData.Raw, data) + if err := chData.Decode(); err != nil { + return err + } + + relayedConn := c.relayedUDPConn() + if relayedConn != nil { + c.log.Debug("no relayed conn allocated") + return nil // silently discard + } + + addr, ok := relayedConn.FindAddrByChannelNumber(uint16(chData.Number)) + if !ok { + return fmt.Errorf("binding with channel %d not found", int(chData.Number)) + } + + relayedConn.HandleInbound(chData.Data, addr) + return nil +} + +func (c *Client) onRtxTimeout(trKey string, nRtx int32) { + tr, ok := c.trMap.Find(trKey) + if !ok { + return // already gone + } + + if nRtx == maxRtxCount { + // all retransmisstions failed + c.trMap.Delete(trKey) + if !tr.WriteResult(client.TransactionResult{ + Err: fmt.Errorf("all retransmissions for %s failed", trKey), + }) { + c.log.Debug("no listener for transaction") + } + return + } + + c.log.Tracef("retransmitting transaction %s to %s (nRtx=%d)", + trKey, tr.To.String(), nRtx) + _, err := c.conn.WriteTo(tr.Raw, tr.To) + if err != nil { + c.trMap.Delete(trKey) + if !tr.WriteResult(client.TransactionResult{ + Err: fmt.Errorf("failed to retransmit transaction %s", trKey), + }) { + c.log.Debug("no listener for transaction") + } + return + } + tr.StartRtxTimer(c.onRtxTimeout) +} + +func (c *Client) setRelayedUDPConn(conn *client.UDPConn) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.relayedConn = conn +} + +func (c *Client) relayedUDPConn() *client.UDPConn { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.relayedConn +} diff --git a/client_test.go b/client_test.go index 7e81e644..2b16fee9 100644 --- a/client_test.go +++ b/client_test.go @@ -3,25 +3,82 @@ package turn import ( "net" "testing" + "time" "github.com/pion/stun" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/pion/logging" + "github.com/stretchr/testify/assert" ) -func TestClient(t *testing.T) { +func createListeningTestClient(t *testing.T, loggerFactory logging.LoggerFactory) (*Client, bool) { + conn, err := net.ListenPacket("udp4", "0.0.0.0:0") + if !assert.NoError(t, err, "should succeed") { + return nil, false + } + c, err := NewClient(&ClientConfig{ + Conn: conn, + Software: &stun.NewSoftware("TEST SOFTWARE"), + LoggerFactory: loggerFactory, + }) + if !assert.NoError(t, err, "should succeed") { + return nil, false + } + err = c.Listen() + if !assert.NoError(t, err, "should succeed") { + return nil, false + } + + return c, true +} + +func createListeningTestClientWithSTUNServ(t *testing.T, loggerFactory logging.LoggerFactory) (*Client, bool) { + conn, err := net.ListenPacket("udp4", "0.0.0.0:0") + if !assert.NoError(t, err, "should succeed") { + return nil, false + } + c, err := NewClient(&ClientConfig{ + STUNServerAddr: "stun1.l.google.com:19302", + Conn: conn, + + LoggerFactory: loggerFactory, + }) + if !assert.NoError(t, err, "should succeed") { + return nil, false + } + err = c.Listen() + if !assert.NoError(t, err, "should succeed") { + return nil, false + } + + return c, true +} + +func TestClientWithSTUN(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() + log := loggerFactory.NewLogger("test") - t.Run("SendSTUNRequest Parallel", func(t *testing.T) { - c, err := NewClient(&ClientConfig{ - ListeningAddress: "0.0.0.0:0", - LoggerFactory: loggerFactory, - }) - if err != nil { - t.Fatal(err) + t.Run("SendBindingRequest", func(t *testing.T) { + c, ok := createListeningTestClientWithSTUNServ(t, loggerFactory) + if !ok { + return } + defer c.Close() + + resp, err := c.SendBindingRequest() + assert.NoError(t, err, "should succeed") + log.Debugf("mapped-addr: %s", resp.String()) + assert.Equal(t, 0, c.trMap.Size(), "should be no transaction left") + }) + + t.Run("SendBindingRequestTo Parallel", func(t *testing.T) { + c, ok := createListeningTestClient(t, loggerFactory) + if !ok { + return + } + defer c.Close() // simple channel fo go routine start signaling started := make(chan struct{}) @@ -29,10 +86,15 @@ func TestClient(t *testing.T) { var err1 error var resp1 interface{} + to, err := net.ResolveUDPAddr("udp4", "stun1.l.google.com:19302") + if !assert.NoError(t, err, "should succeed") { + return + } + // stun1.l.google.com:19302, more at https://gist.github.com/zziuni/3741933#file-stuns-L5 go func() { close(started) - resp1, err1 = c.SendSTUNRequest(net.IPv4(74, 125, 143, 127), 19302) + resp1, err1 = c.SendBindingRequestTo(to) close(finished) }() @@ -40,7 +102,7 @@ func TestClient(t *testing.T) { <-started - resp2, err2 := c.SendSTUNRequest(net.IPv4(74, 125, 143, 127), 19302) + resp2, err2 := c.SendBindingRequestTo(to) if err2 != nil { t.Fatal(err) } else { @@ -55,61 +117,28 @@ func TestClient(t *testing.T) { } }) - t.Run("SendSTUNRequest adds SOFTWARE attribute to message", func(t *testing.T) { - const testSoftware = "CLIENT_SOFTWARE" - - cfg := &ClientConfig{ - ListeningAddress: "0.0.0.0:0", - LoggerFactory: loggerFactory, - Sender: func(conn net.PacketConn, addr net.Addr, attrs ...stun.Setter) error { - msg, err := stun.Build(attrs...) - if err != nil { - return errors.Wrap(err, "could not build message") - } - var software stun.Software - if err = software.GetFrom(msg); err != nil { - return errors.Wrap(err, "could not get SOFTWARE attribute") - } - - assert.Equal(t, testSoftware, software.String()) - - // just forward to the default sender. - return defaultBuildAndSend(conn, addr, attrs...) - }, - } - software := stun.NewSoftware(testSoftware) - cfg.Software = &software + t.Run("NewClient should fail if Conn is nil", func(t *testing.T) { + _, err := NewClient(&ClientConfig{ + LoggerFactory: loggerFactory, + }) + assert.Error(t, err, "should fail") + }) - c, err := NewClient(cfg) - if err != nil { - t.Fatal(err) - } - if _, err = c.SendSTUNRequest(net.IPv4(74, 125, 143, 127), 19302); err != nil { - t.Fatal(err) + t.Run("SendBindingRequestTo timeout", func(t *testing.T) { + c, ok := createListeningTestClient(t, loggerFactory) + if !ok { + return } - }) + defer c.Close() - t.Run("Listen error", func(t *testing.T) { - _, err := NewClient(&ClientConfig{ - ListeningAddress: "255.255.255.256:65535", - LoggerFactory: loggerFactory, - }) - if err == nil { - t.Fatal("listening on 255.255.255.256:65535 should fail") + to, err := net.ResolveUDPAddr("udp4", "127.0.0.1:9") + if !assert.NoError(t, err, "should succeed") { + return } - }) - /* - // Unable to perform this test atm because there is no timeout and the test may run infinitely - t.Run("SendSTUNRequest timeout", func(t *testing.T) { - c, err := NewClient("0.0.0.0:0") - if err != nil { - t.Fatal(err) - } - _, err = c.SendSTUNRequest(net.IPv4(255, 255, 255, 255), 65535) - if err == nil { - t.Fatal("request to 255.255.255.255:65535 should fail") - } - }) - */ + c.rto = 10 * time.Millisecond // force short timeout + + _, err = c.SendBindingRequestTo(to) + log.Debug(err.Error()) + }) } diff --git a/cmd/client/main.go b/cmd/client/main.go index 15d75ad3..1d4aedc4 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -1,8 +1,8 @@ package main import ( - "errors" "flag" + "fmt" "log" "net" @@ -12,19 +12,25 @@ import ( ) func main() { - host := flag.String("host", "74.125.143.127", "IP of TURN Server. Default is the IP of stun1.l.google.com.") + host := flag.String("host", "stun1.l.google.com", "IP of TURN Server. Default is stun1.l.google.com.") port := flag.Int("port", 19302, "Port of TURN server.") software := flag.String("software", "", "The STUN SOFTWARE attribute. Useful for debugging purpose.") flag.Parse() - ip := net.ParseIP(*host) - if ip == nil { - panic(errors.New("failed to parse host IP")) + conn, err := net.ListenPacket("udp4", "0.0.0.0:0") + if err != nil { + panic(err) } + defer func() { + if err2 := conn.Close(); err2 != nil { + panic(err2) + } + }() cfg := &turn.ClientConfig{ - ListeningAddress: "0.0.0.0:0", - LoggerFactory: logging.NewDefaultLoggerFactory(), + STUNServerAddr: fmt.Sprintf("%s:%d", *host, *port), + Conn: conn, + LoggerFactory: logging.NewDefaultLoggerFactory(), } if *software != "" { @@ -36,8 +42,14 @@ func main() { if err != nil { panic(err) } + defer c.Close() + + err = c.Listen() + if err != nil { + panic(err) + } - mappedAddr, err := c.SendSTUNRequest(ip, *port) + mappedAddr, err := c.SendBindingRequest() if err != nil { panic(err) } diff --git a/go.mod b/go.mod index b0bb12fa..9c86f44d 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,9 @@ go 1.12 require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/gortc/turn v0.8.0 - github.com/pion/logging v0.2.1 + github.com/pion/logging v0.2.2 github.com/pion/stun v0.3.1 - github.com/pion/transport v0.8.1 + github.com/pion/transport v0.8.4 github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 golang.org/x/net v0.0.0-20190403144856-b630fd6fe46b diff --git a/go.sum b/go.sum index ee68f450..1babba28 100644 --- a/go.sum +++ b/go.sum @@ -5,10 +5,12 @@ github.com/gortc/turn v0.8.0 h1:WWQi1jkoPmc2E7qgUMcZleveKikT9Ksi3QGIl8ZtY3Q= github.com/gortc/turn v0.8.0/go.mod h1:gvguwaGAFyv5/9KrcW9MkCgHALYD+e99mSM7pSCYYho= github.com/pion/logging v0.2.1 h1:LwASkBKZ+2ysGJ+jLv1E/9H1ge0k1nTfi1X+5zirkDk= github.com/pion/logging v0.2.1/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= +github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= +github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/stun v0.3.1 h1:d09JJzOmOS8ZzIp8NppCMgrxGZpJ4Ix8qirfNYyI3BA= github.com/pion/stun v0.3.1/go.mod h1:xrCld6XM+6GWDZdvjPlLMsTU21rNxnO6UO8XsAvHr/M= -github.com/pion/transport v0.8.1 h1:FUHJFd4MaIEJmlpiGx+ZH8j9JLsERnROHQPA9zNFFAs= -github.com/pion/transport v0.8.1/go.mod h1:nAmRRnn+ArVtsoNuwktvAD+jrjSD7pA+H3iRmZwdUno= +github.com/pion/transport v0.8.4 h1:Wios3j8IFmrli4pHiXhGMVnj1DYWiukcboZGSv8kj2M= +github.com/pion/transport v0.8.4/go.mod h1:nAmRRnn+ArVtsoNuwktvAD+jrjSD7pA+H3iRmZwdUno= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/internal/allocation/allocation.go b/internal/allocation/allocation.go index f78bbd37..a58fea1b 100644 --- a/internal/allocation/allocation.go +++ b/internal/allocation/allocation.go @@ -253,7 +253,9 @@ func (a *Allocation) packetHandler(m *Manager) { if err != nil { a.log.Errorf("Failed to send DataIndication from allocation %v %v", srcAddr, err) } - a.log.Debugf("relaying message to client at %s", a.fiveTuple.SrcAddr.String()) + a.log.Debugf("relaying message from %s to client at %s", + srcAddr.String(), + a.fiveTuple.SrcAddr.String()) if _, err = a.TurnSocket.WriteTo(msg.Raw, a.fiveTuple.SrcAddr); err != nil { a.log.Errorf("Failed to send DataIndication from allocation %v %v", srcAddr, err) } diff --git a/internal/client/atomic_bool.go b/internal/client/atomic_bool.go new file mode 100644 index 00000000..f4dbdda9 --- /dev/null +++ b/internal/client/atomic_bool.go @@ -0,0 +1,39 @@ +package client + +import ( + "sync/atomic" +) + +// AtomicBool is an atomic boolean struct +type AtomicBool struct { + n int32 +} + +// NewAtomicBool creates a new instance of AtomicBool +func NewAtomicBool(initiallyTrue bool) *AtomicBool { + var n int32 + if initiallyTrue { + n = 1 + } + return &AtomicBool{n: n} +} + +// SetToTrue sets this value to true +func (b *AtomicBool) SetToTrue() { + atomic.StoreInt32(&b.n, 1) +} + +// SetToFalse sets this value to false +func (b *AtomicBool) SetToFalse() { + atomic.StoreInt32(&b.n, 0) +} + +// True returns true if it is set to true +func (b *AtomicBool) True() bool { + return atomic.LoadInt32(&b.n) != int32(0) +} + +// False return true if it is set to false +func (b *AtomicBool) False() bool { + return atomic.LoadInt32(&b.n) == int32(0) +} diff --git a/internal/client/atomic_bool_test.go b/internal/client/atomic_bool_test.go new file mode 100644 index 00000000..2c35dcf6 --- /dev/null +++ b/internal/client/atomic_bool_test.go @@ -0,0 +1,24 @@ +package client + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAtomicBool(t *testing.T) { + b0 := NewAtomicBool(false) + assert.False(t, b0.True(), "should false") + assert.True(t, b0.False(), "should false") + + b1 := NewAtomicBool(true) + assert.True(t, b1.True(), "should true") + assert.False(t, b1.False(), "should true") + + b0.SetToTrue() + assert.True(t, b0.True(), "should true") + assert.False(t, b0.False(), "should true") + b0.SetToFalse() + assert.False(t, b0.True(), "should false") + assert.True(t, b0.False(), "should false") +} diff --git a/internal/client/binding.go b/internal/client/binding.go new file mode 100644 index 00000000..d691dcfa --- /dev/null +++ b/internal/client/binding.go @@ -0,0 +1,131 @@ +package client + +import ( + "net" + "sync" + "sync/atomic" +) + +// Chanel number: +// 0x4000 through 0x7FFF: These values are the allowed channel +// numbers (16,383 possible values). +const ( + minChannelNumber uint16 = 0x4000 + maxChannelNumber uint16 = 0x7fff +) + +type bindingState int32 + +const ( + bindingStateIdle bindingState = iota + bindingStateReady + bindingStateFailed +) + +type binding struct { + number uint16 // read-only + st bindingState // thread-safe (atomic op) + addr net.Addr // read-only + mgr *bindingManager // read-only + mutex sync.Mutex // thread-safe, used in UDPConn +} + +func (b *binding) setState(state bindingState) { + atomic.StoreInt32((*int32)(&b.st), int32(state)) +} + +func (b *binding) state() bindingState { + return bindingState(atomic.LoadInt32((*int32)(&b.st))) +} + +// Thread-safe binding map +type bindingManager struct { + chanMap map[uint16]*binding + addrMap map[string]*binding + next uint16 + mutex sync.RWMutex +} + +func newBindingManager() *bindingManager { + return &bindingManager{ + chanMap: map[uint16]*binding{}, + addrMap: map[string]*binding{}, + next: minChannelNumber, + } +} + +func (mgr *bindingManager) assignChannelNumber() uint16 { + n := mgr.next + if mgr.next == maxChannelNumber { + mgr.next = minChannelNumber + } else { + mgr.next++ + } + return n +} + +func (mgr *bindingManager) create(addr net.Addr) *binding { + mgr.mutex.Lock() + defer mgr.mutex.Unlock() + + b := &binding{ + number: mgr.assignChannelNumber(), + addr: addr, + mgr: mgr, + } + + mgr.chanMap[b.number] = b + mgr.addrMap[b.addr.String()] = b + return b +} + +func (mgr *bindingManager) findByAddr(addr net.Addr) (*binding, bool) { + mgr.mutex.RLock() + defer mgr.mutex.RUnlock() + + b, ok := mgr.addrMap[addr.String()] + return b, ok +} + +func (mgr *bindingManager) findByNumber(number uint16) (*binding, bool) { + mgr.mutex.RLock() + defer mgr.mutex.RUnlock() + + b, ok := mgr.chanMap[number] + return b, ok +} + +func (mgr *bindingManager) deleteByAddr(addr net.Addr) bool { + mgr.mutex.Lock() + defer mgr.mutex.Unlock() + + b, ok := mgr.addrMap[addr.String()] + if !ok { + return false + } + + delete(mgr.addrMap, addr.String()) + delete(mgr.chanMap, b.number) + return true +} + +func (mgr *bindingManager) deleteByNumber(number uint16) bool { + mgr.mutex.Lock() + defer mgr.mutex.Unlock() + + b, ok := mgr.chanMap[number] + if !ok { + return false + } + + delete(mgr.addrMap, b.addr.String()) + delete(mgr.chanMap, number) + return true +} + +func (mgr *bindingManager) size() int { + mgr.mutex.RLock() + defer mgr.mutex.RUnlock() + + return len(mgr.chanMap) +} diff --git a/internal/client/binding_test.go b/internal/client/binding_test.go new file mode 100644 index 00000000..31e4e386 --- /dev/null +++ b/internal/client/binding_test.go @@ -0,0 +1,75 @@ +package client + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBindingManager(t *testing.T) { + t.Run("number assignment", func(t *testing.T) { + m := newBindingManager() + var n uint16 + for i := uint16(0); i < 10; i++ { + n = m.assignChannelNumber() + assert.Equal(t, minChannelNumber+i, n, "should match") + } + + m.next = uint16(0x7ff0) + for i := uint16(0); i < 16; i++ { + n = m.assignChannelNumber() + assert.Equal(t, 0x7ff0+i, n, "should match") + } + // back to min + n = m.assignChannelNumber() + assert.Equal(t, minChannelNumber, n, "should match") + }) + + t.Run("method test", func(t *testing.T) { + lo := net.IPv4(127, 0, 0, 1) + count := 100 + m := newBindingManager() + for i := 0; i < count; i++ { + addr := &net.UDPAddr{IP: lo, Port: 10000 + i} + b0 := m.create(addr) + b1, ok := m.findByAddr(addr) + assert.True(t, ok, "should succeed") + b2, ok := m.findByNumber(b0.number) + assert.True(t, ok, "should succeed") + + assert.Equal(t, b0, b1, "should match") + assert.Equal(t, b0, b2, "should match") + } + + assert.Equal(t, count, m.size(), "should match") + assert.Equal(t, count, len(m.addrMap), "should match") + + for i := 0; i < count; i++ { + addr := &net.UDPAddr{IP: lo, Port: 10000 + i} + if i%2 == 0 { + assert.True(t, m.deleteByAddr(addr), "should return true") + } else { + assert.True(t, m.deleteByNumber(minChannelNumber+uint16(i)), "should return true") + } + } + + assert.Equal(t, 0, m.size(), "should match") + assert.Equal(t, 0, len(m.addrMap), "should match") + }) + + t.Run("failure test", func(t *testing.T) { + addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7777} + m := newBindingManager() + var ok bool + _, ok = m.findByAddr(addr) + assert.False(t, ok, "should fail") + _, ok = m.findByNumber(uint16(5555)) + assert.False(t, ok, "should fail") + ok = m.deleteByAddr(addr) + assert.False(t, ok, "should fail") + ok = m.deleteByNumber(uint16(5555)) + assert.False(t, ok, "should fail") + + }) +} diff --git a/internal/client/conn.go b/internal/client/conn.go new file mode 100644 index 00000000..af191ba4 --- /dev/null +++ b/internal/client/conn.go @@ -0,0 +1,520 @@ +package client + +import ( + "fmt" + "io" + "math" + "net" + "sync" + "time" + + "github.com/gortc/turn" + "github.com/pion/logging" + "github.com/pion/stun" +) + +const ( + maxReadQueueSize = 1024 + permRefreshInterval = 120 * time.Second +) + +const ( + timerIDRefreshAlloc int = iota + timerIDRefreshPerms +) + +func noDeadline() time.Time { + return time.Time{} +} + +type inboundData struct { + data []byte + from net.Addr +} + +// UDPConnObserver is an interface to UDPConn observer +type UDPConnObserver interface { + TURNServerAddr() net.Addr + Username() stun.Username + Realm() stun.Realm + WriteTo(data []byte, to net.Addr) (int, error) + PerformTransaction(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error) + OnDeallocated(relayedAddr net.Addr) +} + +// UDPConnConfig is a set of configuration params use by NewUDPConn +type UDPConnConfig struct { + Observer UDPConnObserver + RelayedAddr net.Addr + Integrity stun.MessageIntegrity + Nonce stun.Nonce + Lifetime time.Duration + Log logging.LeveledLogger +} + +// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections. +// comatible with net.PacketConn and net.Conn +type UDPConn struct { + obs UDPConnObserver // read-only + relayedAddr net.Addr // read-only + permMap *permissionMap // thread-safe + bindingMgr *bindingManager // thread-safe + integrity stun.MessageIntegrity // read-only + nonce stun.Nonce // read-only + lifetime time.Duration // needs mutex x + readCh chan *inboundData // thread-safe + closeCh chan struct{} // thread-safe + closed *AtomicBool // thread-safe + readTimer *time.Timer // thread-safe + refreshAllocTimer *PeriodicTimer // thread-safe + refreshPermsTimer *PeriodicTimer // thread-safe + mutex sync.RWMutex // thread-safe + log logging.LeveledLogger // read-only +} + +// NewUDPConn creates a new instance of UDPConn +func NewUDPConn(config *UDPConnConfig) *UDPConn { + c := &UDPConn{ + obs: config.Observer, + relayedAddr: config.RelayedAddr, + permMap: newPermissionMap(), + bindingMgr: newBindingManager(), + integrity: config.Integrity, + nonce: config.Nonce, + lifetime: config.Lifetime, + readCh: make(chan *inboundData, maxReadQueueSize), + closeCh: make(chan struct{}), + closed: NewAtomicBool(false), + readTimer: time.NewTimer(time.Duration(math.MaxInt64)), + log: config.Log, + } + + c.log.Debugf("initial lifetime: %d seconds", int(c.lifetime.Seconds())) + + c.refreshAllocTimer = NewPeriodicTimer( + timerIDRefreshAlloc, + c.onRefreshTimers, + c.lifetime/2, + ) + + c.refreshPermsTimer = NewPeriodicTimer( + timerIDRefreshPerms, + c.onRefreshTimers, + permRefreshInterval, + ) + + if c.refreshAllocTimer.Start() { + c.log.Debugf("refreshAllocTimer started") + } + if c.refreshPermsTimer.Start() { + c.log.Debugf("refreshPermsTimer started") + } + + return c +} + +// ReadFrom reads a packet from the connection, +// copying the payload into p. It returns the number of +// bytes copied into p and the return address that +// was on the packet. +// It returns the number of bytes read (0 <= n <= len(p)) +// and any error encountered. Callers should always process +// the n > 0 bytes returned before considering the error err. +// ReadFrom can be made to time out and return +// an Error with Timeout() == true after a fixed time limit; +// see SetDeadline and SetReadDeadline. +func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + for c.closed.False() { + select { + case ibData := <-c.readCh: + n := copy(p, ibData.data) + if n < len(ibData.data) { + return 0, nil, io.ErrShortBuffer + } + return n, ibData.from, nil + + case <-c.readTimer.C: + return 0, nil, &net.OpError{ + Op: "read", + Net: c.LocalAddr().Network(), + Addr: c.LocalAddr(), + Err: newTimeoutError("i/o timeout"), + } + + case <-c.closeCh: + c.closed.SetToTrue() + } + } + + return 0, nil, &net.OpError{ + Op: "read", + Net: c.LocalAddr().Network(), + Addr: c.LocalAddr(), + Err: fmt.Errorf("use of closed network connection"), + } +} + +// WriteTo writes a packet with payload p to addr. +// WriteTo can be made to time out and return +// an Error with Timeout() == true after a fixed time limit; +// see SetDeadline and SetWriteDeadline. +// On packet-oriented connections, write timeouts are rare. +func (c *UDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { + _, ok := addr.(*net.UDPAddr) + if !ok { + return 0, fmt.Errorf("addr is not a net.UDPAddr") + } + + // check if we have a permission for the destination IP addr + perm, ok := c.permMap.find(addr) + if !ok { + perm = &permission{} + c.permMap.insert(addr, perm) + } + + // This func-block would block, per destination IP (, or perm), until + // the perm state becomes "requested". Purpose of this is to guarantee + // the order of packets (within the same perm). + // Note that CreatePermission transaction may not be complete before + // all the data transmission. This is done assuming that the request + // will be mostly likely successful and we can tolerate some loss of + // UDP packet (or reorder), inorder to minimize the latency in most cases. + err := func() error { + perm.mutex.Lock() + defer perm.mutex.Unlock() + + if perm.state() == permStateIdle { + // punch a hole! (this would block a bit..) + if err := c.createPermissions(addr); err != nil { + c.permMap.delete(addr) + return err + } + perm.setState(permStatePermitted) + } + return nil + }() + if err != nil { + return 0, err + } + + // bind channel + + b, ok := c.bindingMgr.findByAddr(addr) + if !ok { + b = c.bindingMgr.create(addr) + } + if b.state() != bindingStateReady { + if b.state() == bindingStateIdle { + func() { + // block only callers with the same binding until + // the binding transaction has been complete + b.mutex.Lock() + defer b.mutex.Unlock() + + // binding state may have been changed while waiting. check again. + if b.state() == bindingStateIdle { + err = c.bind(b) + if err != nil { + c.log.Warnf("bind() failed: %s", err.Error()) + b.setState(bindingStateFailed) + // keep going... + // TODO: consider try binding again after a while + } else { + b.setState(bindingStateReady) + } + } + }() + } + + // send data using SendIndication + // TODO: send over channel when it becomes available + peerAddr := addr2PeerAddress(addr) + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodSend, stun.ClassIndication), + turn.RequestedTransportUDP, + turn.Data(p), + peerAddr, + stun.Fingerprint, + ) + if err != nil { + return 0, err + } + + // indication has no transaction (fire-and-forget) + + return c.obs.WriteTo(msg.Raw, c.obs.TURNServerAddr()) + } + + // send via ChannelData + return c.sendChannelData(p, b.number) +} + +// Close closes the connection. +// Any blocked ReadFrom or WriteTo operations will be unblocked and return errors. +func (c *UDPConn) Close() error { + c.refreshAllocTimer.Stop() + c.refreshPermsTimer.Stop() + + select { + case <-c.closeCh: + return fmt.Errorf("already closed") + default: + close(c.closeCh) + } + + c.refreshAllocation(0, true) // dontWait = true + c.obs.OnDeallocated(c.relayedAddr) + return nil +} + +// LocalAddr returns the local network address. +func (c *UDPConn) LocalAddr() net.Addr { + return c.relayedAddr +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. It is equivalent to calling both +// SetReadDeadline and SetWriteDeadline. +// +// A deadline is an absolute time after which I/O operations +// fail with a timeout (see type Error) instead of +// blocking. The deadline applies to all future and pending +// I/O, not just the immediately following call to ReadFrom or +// WriteTo. After a deadline has been exceeded, the connection +// can be refreshed by setting a deadline in the future. +// +// An idle timeout can be implemented by repeatedly extending +// the deadline after successful ReadFrom or WriteTo calls. +// +// A zero value for t means I/O operations will not time out. +func (c *UDPConn) SetDeadline(t time.Time) error { + return c.SetReadDeadline(t) +} + +// SetReadDeadline sets the deadline for future ReadFrom calls +// and any currently-blocked ReadFrom call. +// A zero value for t means ReadFrom will not time out. +func (c *UDPConn) SetReadDeadline(t time.Time) error { + var d time.Duration + if t == noDeadline() { + d = time.Duration(math.MaxInt64) + } else { + d = time.Until(t) + } + c.readTimer.Reset(d) + return nil +} + +// SetWriteDeadline sets the deadline for future WriteTo calls +// and any currently-blocked WriteTo call. +// Even if write times out, it may return n > 0, indicating that +// some of the data was successfully written. +// A zero value for t means WriteTo will not time out. +func (c *UDPConn) SetWriteDeadline(t time.Time) error { + // Write never blocks. + return nil +} + +func addr2PeerAddress(addr net.Addr) turn.PeerAddress { + var peerAddr turn.PeerAddress + switch a := addr.(type) { + case *net.UDPAddr: + peerAddr.IP = a.IP + peerAddr.Port = a.Port + case *net.TCPAddr: + peerAddr.IP = a.IP + peerAddr.Port = a.Port + } + + return peerAddr +} + +func (c *UDPConn) createPermissions(addrs ...net.Addr) error { + setters := []stun.Setter{ + stun.TransactionID, + stun.NewType(stun.MethodCreatePermission, stun.ClassRequest), + turn.RequestedTransportUDP, + } + + for _, addr := range addrs { + setters = append(setters, addr2PeerAddress(addr)) + } + + setters = append(setters, + c.obs.Username(), + c.obs.Realm(), + &c.nonce, + &c.integrity, + stun.Fingerprint) + + msg, err := stun.Build(setters...) + if err != nil { + return err + } + + trRes, err := c.obs.PerformTransaction(msg, c.obs.TURNServerAddr(), false) + if err != nil { + return err + } + + res := trRes.Msg + + if res.Type.Class == stun.ClassErrorResponse { + var code stun.ErrorCodeAttribute + if err = code.GetFrom(res); err == nil { + err = fmt.Errorf("%s (error %s)", res.Type, code) + } else { + err = fmt.Errorf("%s", res.Type) + } + return err + } + + return nil +} + +// HandleInbound passes inbound data in UDPConn +func (c *UDPConn) HandleInbound(data []byte, from net.Addr) { + select { + case c.readCh <- &inboundData{data: data, from: from}: + default: + c.log.Warnf("receive buffer full") + } +} + +// FindAddrByChannelNumber returns a peer address associated with the +// channel number on this UDPConn +func (c *UDPConn) FindAddrByChannelNumber(chNum uint16) (net.Addr, bool) { + b, ok := c.bindingMgr.findByNumber(chNum) + if !ok { + return nil, false + } + return b.addr, true +} + +func (c *UDPConn) refreshAllocation(lifetime time.Duration, dontWait bool) { + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodRefresh, stun.ClassRequest), + turn.RequestedTransportUDP, + turn.Lifetime{Duration: lifetime}, + stun.Fingerprint, + ) + if err != nil { + c.log.Errorf("failed to build refresh request: %s", err.Error()) + return + } + + trRes, err := c.obs.PerformTransaction(msg, c.obs.TURNServerAddr(), dontWait) + if err != nil { + c.log.Errorf("failed to refresh refresh: %s", err.Error()) + return + } + + if dontWait { + return + } + + // Getting lifetime from response + var updatedLifetime turn.Lifetime + if err := updatedLifetime.GetFrom(trRes.Msg); err != nil { + c.log.Errorf("failed to get lifetime from refresh response: %s", err.Error()) + return + } + + c.mutex.Lock() + c.lifetime = updatedLifetime.Duration + c.log.Debugf("updated lifetime: %d seconds", int(c.lifetime.Seconds())) + c.mutex.Unlock() +} + +func (c *UDPConn) refreshPermissions() { + addrs := c.permMap.addrs() + if len(addrs) == 0 { + c.log.Debug("no permission to refresh") + return + } + if err := c.createPermissions(addrs...); err != nil { + c.log.Errorf("fail to refresh permissions: %s", err.Error()) + return + } + c.log.Debug("refresh permissions successful") +} + +func (c *UDPConn) bind(b *binding) error { + setters := []stun.Setter{ + stun.TransactionID, + stun.NewType(stun.MethodChannelBind, stun.ClassRequest), + turn.RequestedTransportUDP, + addr2PeerAddress(b.addr), + turn.ChannelNumber(b.number), + c.obs.Username(), + c.obs.Realm(), + c.nonce, + c.integrity, + stun.Fingerprint, + } + + msg, err := stun.Build(setters...) + if err != nil { + return err + } + + trRes, err := c.obs.PerformTransaction(msg, c.obs.TURNServerAddr(), false) + if err != nil { + c.bindingMgr.deleteByAddr(b.addr) + } + + res := trRes.Msg + + if res.Type != stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse) { + return fmt.Errorf("unexpected response type %s", res.Type) + } + + c.log.Debugf("channel binding successful: %s %d", + b.addr.String(), + b.number) + + // Success. + return nil +} + +func (c *UDPConn) sendChannelData(data []byte, chNum uint16) (int, error) { + chData := &turn.ChannelData{ + Data: data, + Number: turn.ChannelNumber(chNum), + } + chData.Encode() + return c.obs.WriteTo(chData.Raw, c.obs.TURNServerAddr()) +} + +func (c *UDPConn) onRefreshTimers(id int) { + c.log.Debugf("refresh timer %d expired", id) + c.mutex.RLock() + lifetime := c.lifetime + c.mutex.RUnlock() + switch id { + case timerIDRefreshAlloc: + c.refreshAllocation(lifetime, false) + case timerIDRefreshPerms: + c.refreshPermissions() + } +} + +type timeoutError struct { + msg string +} + +func newTimeoutError(msg string) error { + return &timeoutError{ + msg: msg, + } +} + +func (e *timeoutError) Error() string { + return e.msg +} + +func (e *timeoutError) Timeout() bool { + return true +} diff --git a/internal/client/periodic_timer.go b/internal/client/periodic_timer.go new file mode 100644 index 00000000..fcd56787 --- /dev/null +++ b/internal/client/periodic_timer.go @@ -0,0 +1,82 @@ +package client + +import ( + "sync" + "time" +) + +// PeriodicTimerTimeoutHandler is a handler called on timeout +type PeriodicTimerTimeoutHandler func(timerID int) + +// PeriodicTimer is a periodic timer +type PeriodicTimer struct { + id int + interval time.Duration + timeoutHandler PeriodicTimerTimeoutHandler + stopFunc func() + mutex sync.RWMutex +} + +// NewPeriodicTimer create a new timer +func NewPeriodicTimer(id int, timeoutHandler PeriodicTimerTimeoutHandler, interval time.Duration) *PeriodicTimer { + return &PeriodicTimer{ + id: id, + interval: interval, + timeoutHandler: timeoutHandler, + } +} + +// Start starts the timer. +func (t *PeriodicTimer) Start() bool { + t.mutex.Lock() + defer t.mutex.Unlock() + + // this is a noop if the timer is always running + if t.stopFunc != nil { + return false + } + + cancelCh := make(chan struct{}) + + go func() { + canceling := false + + for !canceling { + timer := time.NewTimer(t.interval) + + select { + case <-timer.C: + t.timeoutHandler(t.id) + case <-cancelCh: + canceling = true + timer.Stop() + } + } + }() + + t.stopFunc = func() { + close(cancelCh) + } + + return true +} + +// Stop stops the timer. +func (t *PeriodicTimer) Stop() { + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.stopFunc != nil { + t.stopFunc() + t.stopFunc = nil + } +} + +// IsRunning tests if the timer is running. +// Debug purpose only +func (t *PeriodicTimer) IsRunning() bool { + t.mutex.RLock() + defer t.mutex.RUnlock() + + return (t.stopFunc != nil) +} diff --git a/internal/client/periodic_timer_test.go b/internal/client/periodic_timer_test.go new file mode 100644 index 00000000..67c2d01e --- /dev/null +++ b/internal/client/periodic_timer_test.go @@ -0,0 +1,52 @@ +package client + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestPriodicTimer(t *testing.T) { + t.Run("basic", func(t *testing.T) { + timerID := 3 + var nCbs int + rt := NewPeriodicTimer(timerID, func(id int) { + nCbs++ + assert.Equal(t, timerID, id) + }, 50*time.Millisecond) + + assert.False(t, rt.IsRunning(), "should not be running yet") + + ok := rt.Start() + assert.True(t, ok, "should be true") + assert.True(t, rt.IsRunning(), "should be running") + + time.Sleep(100 * time.Millisecond) + + ok = rt.Start() + assert.False(t, ok, "start again is noop") + + time.Sleep(120 * time.Millisecond) + rt.Stop() + assert.False(t, rt.IsRunning(), "should not be running") + assert.Equal(t, 4, nCbs, "should be called 4 times (actual: %d)", nCbs) + }) + + t.Run("stop inside handler", func(t *testing.T) { + timerID := 4 + var rt *PeriodicTimer + rt = NewPeriodicTimer(timerID, func(id int) { + assert.Equal(t, timerID, id) + rt.Stop() + }, 20*time.Millisecond) + + assert.False(t, rt.IsRunning(), "should not be running yet") + + ok := rt.Start() + assert.True(t, ok, "should be true") + assert.True(t, rt.IsRunning(), "should be running") + time.Sleep(30 * time.Millisecond) + assert.False(t, rt.IsRunning(), "should not be running") + }) +} diff --git a/internal/client/permission.go b/internal/client/permission.go new file mode 100644 index 00000000..5546a22e --- /dev/null +++ b/internal/client/permission.go @@ -0,0 +1,90 @@ +package client + +import ( + "net" + "sync" + "sync/atomic" +) + +type permState int32 + +const ( + permStateIdle permState = iota + permStatePermitted +) + +type permission struct { + st permState // thread-safe (atomic op) + mutex sync.RWMutex // thread-safe +} + +func (p *permission) setState(state permState) { + atomic.StoreInt32((*int32)(&p.st), int32(state)) +} + +func (p *permission) state() permState { + return permState(atomic.LoadInt32((*int32)(&p.st))) +} + +// Thread-safe permission map +type permissionMap struct { + permMap map[string]*permission + mutex sync.RWMutex +} + +func (m *permissionMap) insert(addr net.Addr, p *permission) bool { + m.mutex.Lock() + defer m.mutex.Unlock() + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return false + } + + m.permMap[udpAddr.IP.String()] = p + return true +} + +func (m *permissionMap) find(addr net.Addr) (*permission, bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return nil, false + } + + p, ok := m.permMap[udpAddr.IP.String()] + return p, ok +} + +func (m *permissionMap) delete(addr net.Addr) { + m.mutex.Lock() + defer m.mutex.Unlock() + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return + } + + delete(m.permMap, udpAddr.IP.String()) +} + +func (m *permissionMap) addrs() []net.Addr { + m.mutex.RLock() + defer m.mutex.RUnlock() + + addrs := []net.Addr{} + for k := range m.permMap { + addrs = append(addrs, &net.UDPAddr{ + IP: net.ParseIP(k), + }) + } + return addrs +} + +func newPermissionMap() *permissionMap { + return &permissionMap{ + permMap: map[string]*permission{}, + } +} diff --git a/internal/client/transaction.go b/internal/client/transaction.go new file mode 100644 index 00000000..d4fb5d29 --- /dev/null +++ b/internal/client/transaction.go @@ -0,0 +1,156 @@ +package client + +import ( + "net" + "sync" + "time" + + "github.com/pion/stun" +) + +const ( + maxRtxInterval time.Duration = 1600 * time.Millisecond +) + +// TransactionResult is a bag of result values of a transaction +type TransactionResult struct { + Msg *stun.Message + From net.Addr + Err error +} + +// TransactionConfig is a set of confi params used by NewTransaction +type TransactionConfig struct { + Key string + Raw []byte + To net.Addr + Interval time.Duration +} + +// Transaction represents a transaction +type Transaction struct { + Key string // read-only + Raw []byte // read-only + To net.Addr // read-only + nRtx int32 // modified only by the timer thread + interval time.Duration // modified only by the timer thread + timer *time.Timer // therad-safe, set only by the creator, and stopper + resultCh chan TransactionResult // thread-safe + mutex sync.RWMutex +} + +// NewTransaction creates a new instance of Transaction +func NewTransaction(config *TransactionConfig) *Transaction { + return &Transaction{ + Key: config.Key, // read-only + Raw: config.Raw, // read-only + To: config.To, // read-only + interval: config.Interval, // modified only by the timer thread + resultCh: make(chan TransactionResult), // thread-safe + } +} + +// StartRtxTimer starts the transaction timer +func (t *Transaction) StartRtxTimer(onTimeout func(trKey string, nRtx int32)) { + t.mutex.Lock() + defer t.mutex.Unlock() + + t.timer = time.AfterFunc(t.interval, func() { + t.nRtx++ + t.interval *= 2 + if t.interval > maxRtxInterval { + t.interval = maxRtxInterval + } + onTimeout(t.Key, t.nRtx) + }) +} + +// StopRtxTimer stop the transaction timer +func (t *Transaction) StopRtxTimer() { + t.mutex.RLock() + defer t.mutex.RUnlock() + + if t != nil { + t.timer.Stop() + } +} + +// WriteResult writes the result to the result channel +func (t *Transaction) WriteResult(res TransactionResult) bool { + select { + case t.resultCh <- res: + return true + default: + } + return false +} + +// WaitForResult waits for the transaction result +func (t *Transaction) WaitForResult() TransactionResult { + return <-t.resultCh +} + +// Close closes the transaction +func (t *Transaction) Close() { + close(t.resultCh) +} + +//////////////////////////////////////////////////////////////////////////////// + +// TransactionMap is a thread-safe transaction map +type TransactionMap struct { + trMap map[string]*Transaction + mutex sync.RWMutex +} + +// NewTransactionMap create a new instance of the transaction map +func NewTransactionMap() *TransactionMap { + return &TransactionMap{ + trMap: map[string]*Transaction{}, + } +} + +// Insert inserts a trasaction to the map +func (m *TransactionMap) Insert(key string, tr *Transaction) bool { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.trMap[key] = tr + return true +} + +// Find looks up a transaction by its key +func (m *TransactionMap) Find(key string) (*Transaction, bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + tr, ok := m.trMap[key] + return tr, ok +} + +// Delete deletes a transaction by its key +func (m *TransactionMap) Delete(key string) { + m.mutex.Lock() + defer m.mutex.Unlock() + + delete(m.trMap, key) +} + +// CloseAndDeleteAll closes and deletes all transactions +func (m *TransactionMap) CloseAndDeleteAll() { + m.mutex.Lock() + defer m.mutex.Unlock() + + for trKey, tr := range m.trMap { + tr.Close() + delete(m.trMap, trKey) + } +} + +// Size returns the length of the transaction map +func (m *TransactionMap) Size() int { + m.mutex.RLock() + defer m.mutex.RUnlock() + + return len(m.trMap) +} diff --git a/internal/client/trylock.go b/internal/client/trylock.go new file mode 100644 index 00000000..4f555345 --- /dev/null +++ b/internal/client/trylock.go @@ -0,0 +1,25 @@ +package client + +import ( + "fmt" + "sync/atomic" +) + +// TryLock implement the classic "try-lock" operation. +type TryLock struct { + n int32 +} + +// Lock tries to lock the try-lock. If successful, it returns true. +// Otherwise, it returns false immedidately. +func (c *TryLock) Lock() error { + if !atomic.CompareAndSwapInt32(&c.n, 0, 1) { + return fmt.Errorf("try-lock is already locked") + } + return nil +} + +// Unlock unlocks the try-lock. +func (c *TryLock) Unlock() { + atomic.StoreInt32(&c.n, 0) +} diff --git a/internal/client/trylock_test.go b/internal/client/trylock_test.go new file mode 100644 index 00000000..74af9973 --- /dev/null +++ b/internal/client/trylock_test.go @@ -0,0 +1,61 @@ +package client + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTryLock(t *testing.T) { + t.Run("success case", func(t *testing.T) { + cl := &TryLock{} + testFunc := func() error { + if err := cl.Lock(); err != nil { + return err + } + defer cl.Unlock() + return nil + } + + err := testFunc() + assert.NoError(t, err, "should succeed") + assert.Equal(t, int32(0), cl.n, "should match") + }) + + t.Run("failure case", func(t *testing.T) { + cl := &TryLock{} + testFunc := func() error { + if err := cl.Lock(); err != nil { + return err + } + defer cl.Unlock() + time.Sleep(50 * time.Millisecond) + return nil + } + + var err1, err2 error + doneCh1 := make(chan struct{}) + doneCh2 := make(chan struct{}) + + go func() { + err1 = testFunc() + close(doneCh1) + }() + go func() { + err2 = testFunc() + close(doneCh2) + }() + + <-doneCh1 + <-doneCh2 + + // Either one of them should fail + if err1 == nil { + assert.Error(t, err2, "should fail") + } else { + assert.Error(t, err1, "should fail") + } + assert.Equal(t, int32(0), cl.n, "should match") + }) +} diff --git a/server.go b/server.go index c98f7884..c368af5c 100644 --- a/server.go +++ b/server.go @@ -321,6 +321,7 @@ func (s *Server) handleUDPPacket(conn net.PacketConn, srcAddr net.Addr, buf []by // caller must hold the mutex func (s *Server) handleDataPacket(conn net.PacketConn, srcAddr net.Addr, buf []byte) error { + s.log.Debugf("received DataPacket from %s", srcAddr.String()) c := turn.ChannelData{Raw: buf} if err := c.Decode(); err != nil { return errors.Wrap(err, "Failed to create channel data from packet") @@ -336,6 +337,7 @@ func (s *Server) handleDataPacket(conn net.PacketConn, srcAddr net.Addr, buf []b // caller must hold the mutex func (s *Server) handleTURNPacket(conn net.PacketConn, srcAddr net.Addr, buf []byte) error { + s.log.Debug("handleTURNPacket") m := &stun.Message{Raw: append([]byte{}, buf...)} if err := m.Decode(); err != nil { return errors.Wrap(err, "failed to create stun message from packet") diff --git a/server_test.go b/server_test.go index 61dcd0a7..bb3ed165 100644 --- a/server_test.go +++ b/server_test.go @@ -17,8 +17,9 @@ func TestServer(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") - credMap := map[string]string{} - credMap["user"] = "pass" + credMap := map[string]string{ + "user": "pass", + } t.Run("simple", func(t *testing.T) { @@ -47,16 +48,26 @@ func TestServer(t *testing.T) { time.Sleep(100 * time.Microsecond) log.Debug("creating a client.") + conn, err := net.ListenPacket("udp4", "0.0.0.0:0") + if !assert.NoError(t, err, "should succeed") { + return + } client, err := NewClient(&ClientConfig{ - ListeningAddress: "0.0.0.0:0", - LoggerFactory: loggerFactory, + Conn: conn, + LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } + err = client.Listen() + if !assert.NoError(t, err, "should succeed") { + return + } + defer client.Close() log.Debug("sending a binding request.") - resp, err := client.SendSTUNRequest(net.IPv4(127, 0, 0, 1), 3478) + to := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 3478} + resp, err := client.SendBindingRequestTo(to) assert.NoError(t, err, "should succeed") t.Logf("resp: %v", resp) diff --git a/server_vnet_test.go b/server_vnet_test.go index 7efe6ef9..da3e16f8 100644 --- a/server_vnet_test.go +++ b/server_vnet_test.go @@ -10,7 +10,24 @@ import ( "github.com/stretchr/testify/assert" ) -func buildVNet() (*vnet.Router, *vnet.Net, *vnet.Net, error) { +type VNet struct { + wan *vnet.Router + net0 *vnet.Net // net (0) on the WAN + net1 *vnet.Net // net (1) on the WAN + netL0 *vnet.Net // net (0) on the LAN + server *Server +} + +func (v *VNet) Close() error { + err := v.server.Close() + v.wan.Stop() // nolint:errcheck,gosec + if err != nil { + return err + } + return nil +} + +func buildVNet() (*VNet, error) { loggerFactory := logging.NewDefaultLoggerFactory() // WAN @@ -19,16 +36,25 @@ func buildVNet() (*vnet.Router, *vnet.Net, *vnet.Net, error) { LoggerFactory: loggerFactory, }) if err != nil { - return nil, nil, nil, err + return nil, err } - wanNet := vnet.NewNet(&vnet.NetConfig{ + net0 := vnet.NewNet(&vnet.NetConfig{ StaticIP: "1.2.3.4", // will be assigned to eth0 }) - err = wan.AddNet(wanNet) + err = wan.AddNet(net0) if err != nil { - return nil, nil, nil, err + return nil, err + } + + net1 := vnet.NewNet(&vnet.NetConfig{ + StaticIP: "1.2.3.5", // will be assigned to eth0 + }) + + err = wan.AddNet(net1) + if err != nil { + return nil, err } // LAN @@ -42,82 +68,109 @@ func buildVNet() (*vnet.Router, *vnet.Net, *vnet.Net, error) { LoggerFactory: loggerFactory, }) if err != nil { - return nil, nil, nil, err + return nil, err } - lanNet := vnet.NewNet(&vnet.NetConfig{}) - err = lan.AddNet(lanNet) + netL0 := vnet.NewNet(&vnet.NetConfig{}) + err = lan.AddNet(netL0) if err != nil { - return nil, nil, nil, err + return nil, err } err = wan.AddRouter(lan) if err != nil { - return nil, nil, nil, err + return nil, err } err = wan.Start() if err != nil { - return nil, nil, nil, err + return nil, err } - return wan, wanNet, lanNet, nil -} + // start server... + credMap := map[string]string{} + credMap["user"] = "pass" -func TestServerVNet(t *testing.T) { - loggerFactory := logging.NewDefaultLoggerFactory() - log := loggerFactory.NewLogger("test") + server := NewServer(&ServerConfig{ + AuthHandler: func(username string, srcAddr net.Addr) (password string, ok bool) { + if pw, ok := credMap[username]; ok { + return pw, true + } + return "", false + }, + Realm: "pion.ly", + Net: net0, + LoggerFactory: loggerFactory, + }) - t.Run("simple", func(t *testing.T) { - wan, wanNet, lanNet, err := buildVNet() - assert.NoError(t, err, "should succeed") - defer wan.Stop() // nolint:errcheck + err = server.AddListeningIPAddr("1.2.3.4") + if err != nil { + return nil, err + } - credMap := map[string]string{} - credMap["user"] = "pass" + // register host names + err = wan.AddHost("stun.pion.ly", "1.2.3.4") + if err != nil { + return nil, err + } + err = wan.AddHost("turn.pion.ly", "1.2.3.4") + if err != nil { + return nil, err + } + err = wan.AddHost("echo.pion.ly", "1.2.3.5") + if err != nil { + return nil, err + } - server := NewServer(&ServerConfig{ - AuthHandler: func(username string, srcAddr net.Addr) (password string, ok bool) { - if pw, ok := credMap[username]; ok { - return pw, true - } - return "", false - }, - Realm: "pion.ly", - Net: wanNet, - LoggerFactory: loggerFactory, - }) + err = server.Start() + if err != nil { + wan.Stop() // nolint:errcheck,gosec + return nil, err + } - err = server.AddListeningIPAddr("1.2.3.4") - assert.NoError(t, err, "should succeed") + return &VNet{ + wan: wan, + net0: net0, + net1: net1, + netL0: netL0, + server: server, + }, nil +} - doneCh := make(chan struct{}) +func TestServerVNet(t *testing.T) { + loggerFactory := logging.NewDefaultLoggerFactory() + log := loggerFactory.NewLogger("test") - go func() { - log.Debug("start listening...") - err2 := server.Start() - if err2 != nil { - t.Logf("Start returned with err: %v", err2) - } - close(doneCh) - }() + t.Run("SendBindingRequest", func(t *testing.T) { + v, err := buildVNet() + if !assert.NoError(t, err, "should succeed") { + return + } + defer v.Close() // nolint:errcheck - // make sure the server is listening before running - // the client. - time.Sleep(100 * time.Microsecond) + lconn, err := v.netL0.ListenPacket("udp4", "0.0.0.0:0") + if !assert.NoError(t, err, "should succeed") { + return + } log.Debug("creating a client.") client, err := NewClient(&ClientConfig{ - ListeningAddress: "0.0.0.0:0", - Net: lanNet, - LoggerFactory: loggerFactory, + STUNServerAddr: "1.2.3.4:3478", + Conn: lconn, + Net: v.netL0, + LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } + err = client.Listen() + if !assert.NoError(t, err, "should succeed") { + return + } + defer client.Close() log.Debug("sending a binding request.") - reflAddr, err := client.SendSTUNRequest(net.IPv4(1, 2, 3, 4), 3478) + reflAddr, err := client.SendBindingRequest() if !assert.NoError(t, err, "should succeed") { return } @@ -127,11 +180,101 @@ func TestServerVNet(t *testing.T) { // The mapped-address should have IP address that was assigned // to the LAN router. assert.True(t, udpAddr.IP.Equal(net.IPv4(5, 6, 7, 8)), "should match") + }) + + t.Run("Echo via relay", func(t *testing.T) { + v, err := buildVNet() + if !assert.NoError(t, err, "should succeed") { + return + } + defer v.Close() // nolint:errcheck + + lconn, err := v.netL0.ListenPacket("udp4", "0.0.0.0:0") + if !assert.NoError(t, err, "should succeed") { + return + } + + log.Debug("creating a client.") + client, err := NewClient(&ClientConfig{ + STUNServerAddr: "stun.pion.ly:3478", + TURNServerAddr: "turn.pion.ly:3478", + Username: "user", + Password: "pass", + Conn: lconn, + Net: v.netL0, + LoggerFactory: loggerFactory, + }) + if !assert.NoError(t, err, "should succeed") { + return + } + err = client.Listen() + if !assert.NoError(t, err, "should succeed") { + return + } + defer client.Close() + + log.Debug("sending a binding request.") + conn, err := client.Allocate() + if !assert.NoError(t, err, "should succeed") { + return + } + + log.Debugf("laddr: %s", conn.LocalAddr().String()) + + echoConn, err := v.net1.ListenPacket("udp4", "1.2.3.5:5678") + if !assert.NoError(t, err, "should succeed") { + return + } + defer echoConn.Close() // nolint:errcheck + + go func() { + buf := make([]byte, 1500) + for { + n, from, err2 := echoConn.ReadFrom(buf) + if err2 != nil { + break + } + log.Debugf("echo: received %d bytes from %s: %s", n, from.String(), string(buf[:n])) + + // verify the message was received from the relay address + if !assert.Equal(t, conn.LocalAddr().String(), from.String(), "should match") { + break + } + + // verify the message received is correct + if !assert.Equal(t, "Hello", string(buf[:n]), "should match") { + break + } + + // echo the data + _, err2 = echoConn.WriteTo(buf[:n], from) + if !assert.NoError(t, err2, "should succeed") { + break + } + } + }() + + buf := make([]byte, 1500) + + for i := 0; i < 4; i++ { + log.Debug("sending \"Hello\"..") + _, err = conn.WriteTo([]byte("Hello"), echoConn.LocalAddr()) + if !assert.NoError(t, err, "should succeed") { + return + } + + _, from, err2 := conn.ReadFrom(buf) + assert.NoError(t, err2, "should succeed") + + // verify the message was received from the relay address + assert.Equal(t, echoConn.LocalAddr().String(), from.String(), "should match") + + time.Sleep(200 * time.Millisecond) + } - // Close server - err = server.Close() + err = conn.Close() assert.NoError(t, err, "should succeed") - <-doneCh + time.Sleep(1 * time.Second) // just to see what happens.. }) } diff --git a/stun.go b/stun.go index ae0131f0..becdfb13 100644 --- a/stun.go +++ b/stun.go @@ -9,6 +9,7 @@ import ( // caller must hold the mutex func (s *Server) handleBindingRequest(conn net.PacketConn, srcAddr net.Addr, m *stun.Message) error { + s.log.Debugf("received BindingRequest from %s", srcAddr.String()) ip, port, err := ipnet.AddrIPPort(srcAddr) if err != nil { return err diff --git a/turn.go b/turn.go index d6839c5d..2181b2b4 100644 --- a/turn.go +++ b/turn.go @@ -2,6 +2,7 @@ package turn import ( "crypto/md5" // #nosec + "fmt" "net" "strings" "time" @@ -94,6 +95,7 @@ func assertDontFragment(curriedSend curriedSend, m *stun.Message, attr stun.Sett // https://tools.ietf.org/html/rfc5766#section-6.2 // caller must hold the mutex func (s *Server) handleAllocateRequest(conn net.PacketConn, srcAddr net.Addr, m *stun.Message) error { + s.log.Debugf("received AllocateRequest from %s", srcAddr.String()) dstAddr := conn.LocalAddr() curriedSend := func(class stun.MessageClass, method stun.Method, transactionID [stun.TransactionIDSize]byte, attrs ...stun.Setter) error { return s.sender(conn, srcAddr, s.makeAttrs(transactionID, stun.NewType(method, class), attrs...)...) @@ -271,6 +273,7 @@ func (s *Server) handleAllocateRequest(conn net.PacketConn, srcAddr net.Addr, m // caller must hold the mutex func (s *Server) handleRefreshRequest(conn net.PacketConn, srcAddr net.Addr, m *stun.Message) error { + s.log.Debugf("received RefreshRequest from %s", srcAddr.String()) dstAddr := conn.LocalAddr() curriedSend := func(class stun.MessageClass, method stun.Method, transactionID [stun.TransactionIDSize]byte, attrs ...stun.Setter) error { return s.sender(conn, srcAddr, s.makeAttrs(transactionID, stun.NewType(method, class), attrs...)...) @@ -302,6 +305,7 @@ func (s *Server) handleRefreshRequest(conn net.PacketConn, srcAddr net.Addr, m * // caller must hold the mutex func (s *Server) handleCreatePermissionRequest(conn net.PacketConn, srcAddr net.Addr, m *stun.Message) error { + s.log.Debugf("received CreatePermission from %s", srcAddr.String()) dstAddr := conn.LocalAddr() curriedSend := func(class stun.MessageClass, method stun.Method, transactionID [stun.TransactionIDSize]byte, attrs ...stun.Setter) error { return s.sender(conn, srcAddr, s.makeAttrs(transactionID, stun.NewType(method, class), attrs...)...) @@ -328,6 +332,8 @@ func (s *Server) handleCreatePermissionRequest(conn net.PacketConn, srcAddr net. return err } + s.log.Debugf("adding permission for %s", fmt.Sprintf("%s:%d", + peerAddress.IP.String(), peerAddress.Port)) a.AddPermission(allocation.NewPermission( &net.UDPAddr{ IP: peerAddress.IP, @@ -352,6 +358,7 @@ func (s *Server) handleCreatePermissionRequest(conn net.PacketConn, srcAddr net. // caller must hold the mutex func (s *Server) handleSendIndication(conn net.PacketConn, srcAddr net.Addr, m *stun.Message) error { + s.log.Debugf("received SendIndication from %s", srcAddr.String()) dstAddr := conn.LocalAddr() a := s.manager.GetAllocation(&allocation.FiveTuple{ SrcAddr: srcAddr, @@ -386,6 +393,7 @@ func (s *Server) handleSendIndication(conn net.PacketConn, srcAddr net.Addr, m * // caller must hold the mutex func (s *Server) handleChannelBindRequest(conn net.PacketConn, srcAddr net.Addr, m *stun.Message) error { + s.log.Debugf("received ChannelBindRequest from %s", srcAddr.String()) dstAddr := conn.LocalAddr() errorSend := func(err error, attrs ...stun.Setter) error { sendErr := s.sender(conn, srcAddr, s.makeAttrs(m.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassErrorResponse), attrs...)...) @@ -421,6 +429,9 @@ func (s *Server) handleChannelBindRequest(conn net.PacketConn, srcAddr net.Addr, return errorSend(err, stun.CodeBadRequest) } + s.log.Debugf("binding channel %d to %s", + channel, + fmt.Sprintf("%s:%d", peerAddr.IP.String(), peerAddr.Port)) err = a.AddChannelBind(allocation.NewChannelBind( channel, &net.UDPAddr{IP: peerAddr.IP, Port: peerAddr.Port},