diff --git a/client_test.go b/client_test.go index 0830af34..2e7e9346 100644 --- a/client_test.go +++ b/client_test.go @@ -177,11 +177,6 @@ func TestClientNonceExpiration(t *testing.T) { allocation, err := client.Allocate() assert.NoError(t, err) - server.nonces.Range(func(key, value interface{}) bool { - server.nonces.Delete(key) - return true - }) - _, err = allocation.WriteTo([]byte{0x00}, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}) assert.NoError(t, err) diff --git a/internal/server/errors.go b/internal/server/errors.go index 3a8911c4..ea6da834 100644 --- a/internal/server/errors.go +++ b/internal/server/errors.go @@ -7,8 +7,8 @@ import "errors" var ( errFailedToGenerateNonce = errors.New("failed to generate nonce") + errInvalidNonce = errors.New("invalid nonce") errFailedToSendError = errors.New("failed to send error message") - errDuplicatedNonce = errors.New("duplicated Nonce generated, discarding request") errNoSuchUser = errors.New("no such user exists") errUnexpectedClass = errors.New("unexpected class") errUnexpectedMethod = errors.New("unexpected method") diff --git a/internal/server/nonce.go b/internal/server/nonce.go new file mode 100644 index 00000000..b3f3131e --- /dev/null +++ b/internal/server/nonce.go @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package server + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "time" +) + +const ( + nonceLifetime = time.Hour // See: https://tools.ietf.org/html/rfc5766#section-4 + nonceLength = 40 + nonceKeyLength = 64 +) + +// NewNonceHash creates a NonceHash +func NewNonceHash() (*NonceHash, error) { + key := make([]byte, nonceKeyLength) + if _, err := rand.Read(key); err != nil { + return nil, err + } + + return &NonceHash{key}, nil +} + +// NonceHash is used to create and verify nonces +type NonceHash struct { + key []byte +} + +// Generate a nonce +func (n *NonceHash) Generate() (string, error) { + nonce := make([]byte, 8, nonceLength) + binary.BigEndian.PutUint64(nonce, uint64(time.Now().UnixMilli())) + + hash := hmac.New(sha256.New, n.key) + if _, err := hash.Write(nonce[:8]); err != nil { + return "", fmt.Errorf("%w: %v", errFailedToGenerateNonce, err) //nolint:errorlint + } + nonce = hash.Sum(nonce) + + return hex.EncodeToString(nonce), nil +} + +// Validate checks that nonce is signed and is not expired +func (n *NonceHash) Validate(nonce string) error { + b, err := hex.DecodeString(nonce) + if err != nil || len(b) != nonceLength { + return fmt.Errorf("%w: %v", errInvalidNonce, err) //nolint:errorlint + } + + if ts := time.UnixMilli(int64(binary.BigEndian.Uint64(b))); time.Since(ts) > nonceLifetime { + return errInvalidNonce + } + + hash := hmac.New(sha256.New, n.key) + if _, err = hash.Write(b[:8]); err != nil { + return fmt.Errorf("%w: %v", errInvalidNonce, err) //nolint:errorlint + } + if !hmac.Equal(b[8:], hash.Sum(nil)) { + return errInvalidNonce + } + + return nil +} diff --git a/internal/server/nonce_test.go b/internal/server/nonce_test.go new file mode 100644 index 00000000..1b92de32 --- /dev/null +++ b/internal/server/nonce_test.go @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNonceHash(t *testing.T) { + t.Run("generated hashes validate", func(t *testing.T) { + h, err := NewNonceHash() + assert.NoError(t, err) + nonce, err := h.Generate() + assert.NoError(t, err) + assert.NoError(t, h.Validate(nonce)) + }) +} diff --git a/internal/server/server.go b/internal/server/server.go index 9dadf133..ae2dcec5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,7 +7,6 @@ package server import ( "fmt" "net" - "sync" "time" "github.com/pion/logging" @@ -25,7 +24,7 @@ type Request struct { // Server State AllocationManager *allocation.Manager - Nonces *sync.Map + NonceHash *NonceHash // User Configuration AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool) diff --git a/internal/server/turn_test.go b/internal/server/turn_test.go index 3e419407..b188facb 100644 --- a/internal/server/turn_test.go +++ b/internal/server/turn_test.go @@ -8,7 +8,6 @@ package server import ( "net" - "sync" "testing" "time" @@ -80,18 +79,21 @@ func TestAllocationLifeTime(t *testing.T) { }) assert.NoError(t, err) - staticKey := []byte("ABC") + nonceHash, err := NewNonceHash() + assert.NoError(t, err) + staticKey, err := nonceHash.Generate() + assert.NoError(t, err) + r := Request{ AllocationManager: allocationManager, - Nonces: &sync.Map{}, + NonceHash: nonceHash, Conn: l, SrcAddr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5000}, Log: logger, AuthHandler: func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool) { - return staticKey, true + return []byte(staticKey), true }, } - r.Nonces.Store(string(staticKey), time.Now()) fiveTuple := &allocation.FiveTuple{SrcAddr: r.SrcAddr, DstAddr: r.Conn.LocalAddr(), Protocol: allocation.UDP} diff --git a/internal/server/util.go b/internal/server/util.go index d11f5291..d34d7b15 100644 --- a/internal/server/util.go +++ b/internal/server/util.go @@ -4,14 +4,9 @@ package server import ( - "crypto/md5" //nolint:gosec,gci - "crypto/rand" "errors" "fmt" - "io" - "math/big" "net" - "strconv" "time" "github.com/pion/stun/v2" @@ -20,29 +15,8 @@ import ( const ( maximumAllocationLifetime = time.Hour // See: https://tools.ietf.org/html/rfc5766#section-6.2 defines 3600 seconds recommendation - nonceLifetime = time.Hour // See: https://tools.ietf.org/html/rfc5766#section-4 ) -func buildNonce() (string, error) { - /* #nosec */ - h := md5.New() - if _, err := io.WriteString(h, strconv.FormatInt(time.Now().Unix(), 10)); err != nil { - return "", fmt.Errorf("%w: %v", errFailedToGenerateNonce, err) //nolint:errorlint - } - - maxInt63 := big.NewInt(1<<63 - 1) - maxInt63.Add(maxInt63, big.NewInt(1)) - randInt63, err := rand.Int(rand.Reader, maxInt63) - if err != nil { - return "", fmt.Errorf("%w: %v", errFailedToGenerateNonce, err) //nolint:errorlint - } - - if _, err := io.WriteString(h, randInt63.String()); err != nil { //nolint:gosec - return "", fmt.Errorf("%w: %v", errFailedToGenerateNonce, err) //nolint:errorlint - } - return fmt.Sprintf("%x", h.Sum(nil)), nil -} - func buildAndSend(conn net.PacketConn, dst net.Addr, attrs ...stun.Setter) error { msg, err := stun.Build(attrs...) if err != nil { @@ -70,16 +44,11 @@ func buildMsg(transactionID [stun.TransactionIDSize]byte, msgType stun.MessageTy func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) (stun.MessageIntegrity, bool, error) { respondWithNonce := func(responseCode stun.ErrorCode) (stun.MessageIntegrity, bool, error) { - nonce, err := buildNonce() + nonce, err := r.NonceHash.Generate() if err != nil { return nil, false, err } - // Nonce has already been taken - if _, keyCollision := r.Nonces.LoadOrStore(nonce, time.Now()); keyCollision { - return nil, false, errDuplicatedNonce - } - return nil, false, buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(callingMethod, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: responseCode}, @@ -101,15 +70,8 @@ func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) } - // Assert Nonce exists and is not expired - nonceCreationTime, nonceFound := r.Nonces.Load(string(*nonceAttr)) - if !nonceFound { - r.Nonces.Delete(nonceAttr) - return respondWithNonce(stun.CodeStaleNonce) - } - - if timeValue, ok := nonceCreationTime.(time.Time); !ok || time.Since(timeValue) >= nonceLifetime { - r.Nonces.Delete(nonceAttr) + // Assert Nonce is signed and is not expired + if err := r.NonceHash.Validate(nonceAttr.String()); err != nil { return respondWithNonce(stun.CodeStaleNonce) } diff --git a/server.go b/server.go index 7da6ba72..db1c1852 100644 --- a/server.go +++ b/server.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "net" - "sync" "time" "github.com/pion/logging" @@ -27,7 +26,7 @@ type Server struct { authHandler AuthHandler realm string channelBindTimeout time.Duration - nonces *sync.Map + nonceHash *server.NonceHash packetConnConfigs []PacketConnConfig listenerConfigs []ListenerConfig @@ -53,6 +52,11 @@ func NewServer(config ServerConfig) (*Server, error) { mtu = config.InboundMTU } + nonceHash, err := server.NewNonceHash() + if err != nil { + return nil, err + } + s := &Server{ log: loggerFactory.NewLogger("turn"), authHandler: config.AuthHandler, @@ -60,7 +64,7 @@ func NewServer(config ServerConfig) (*Server, error) { channelBindTimeout: config.ChannelBindTimeout, packetConnConfigs: config.PacketConnConfigs, listenerConfigs: config.ListenerConfigs, - nonces: &sync.Map{}, + nonceHash: nonceHash, inboundMTU: mtu, } @@ -205,7 +209,7 @@ func (s *Server) readLoop(p net.PacketConn, allocationManager *allocation.Manage Realm: s.realm, AllocationManager: allocationManager, ChannelBindTimeout: s.channelBindTimeout, - Nonces: s.nonces, + NonceHash: s.nonceHash, }); err != nil { s.log.Errorf("Failed to handle datagram: %v", err) }