Skip to content

Commit

Permalink
Bug fix for invalid channel addr
Browse files Browse the repository at this point in the history
Change to connect's source address.
  • Loading branch information
songjiayang authored and songjiayang committed Jul 11, 2019
1 parent ebf4cb7 commit bb3bfbf
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 52 deletions.
26 changes: 13 additions & 13 deletions internal/allocation/allocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,15 @@ func (a *Allocation) RemovePermission(addr net.Addr) {
// permissions needed for this ChannelBind
func (a *Allocation) AddChannelBind(c *ChannelBind, lifetime time.Duration) error {
// Check that this channel id isn't bound to another transport address, and
// that this transport address isn't bound to another channel id.
channelByID := a.GetChannelByID(c.ID)
// that this transport address isn't bound to another channel number.
channelByNumber := a.GetChannelByNumber(c.Number)
channelByPeer := a.GetChannelByAddr(c.Peer)
if channelByID != channelByPeer {
if channelByNumber != channelByPeer {
return errors.Errorf("You cannot use the same channel number with different peer")
}

// Add or refresh this channel.
if channelByID == nil {
if channelByNumber == nil {
a.channelBindingsLock.Lock()
defer a.channelBindingsLock.Unlock()

Expand All @@ -111,22 +111,22 @@ func (a *Allocation) AddChannelBind(c *ChannelBind, lifetime time.Duration) erro
// Channel binds also refresh permissions.
a.AddPermission(NewPermission(c.Peer, a.log))
} else {
channelByID.refresh(lifetime)
channelByNumber.refresh(lifetime)

// Channel binds also refresh permissions.
a.AddPermission(NewPermission(channelByID.Peer, a.log))
a.AddPermission(NewPermission(channelByNumber.Peer, a.log))
}

return nil
}

// RemoveChannelBind removes the ChannelBind from this allocation by id
func (a *Allocation) RemoveChannelBind(id turn.ChannelNumber) bool {
func (a *Allocation) RemoveChannelBind(number turn.ChannelNumber) bool {
a.channelBindingsLock.Lock()
defer a.channelBindingsLock.Unlock()

for i := len(a.channelBindings) - 1; i >= 0; i-- {
if a.channelBindings[i].ID == id {
if a.channelBindings[i].Number == number {
a.channelBindings = append(a.channelBindings[:i], a.channelBindings[i+1:]...)
return true
}
Expand All @@ -135,12 +135,12 @@ func (a *Allocation) RemoveChannelBind(id turn.ChannelNumber) bool {
return false
}

// GetChannelByID gets the ChannelBind from this allocation by id
func (a *Allocation) GetChannelByID(id turn.ChannelNumber) *ChannelBind {
// GetChannelByNumber gets the ChannelBind from this allocation by id
func (a *Allocation) GetChannelByNumber(number turn.ChannelNumber) *ChannelBind {
a.channelBindingsLock.RLock()
defer a.channelBindingsLock.RUnlock()
for _, cb := range a.channelBindings {
if cb.ID == id {
if cb.Number == number {
return cb
}
}
Expand Down Expand Up @@ -235,9 +235,9 @@ func (a *Allocation) packetHandler(m *Manager) {
n,
srcAddr.String())

if channel := a.GetChannelByAddr(a.RelaySocket.LocalAddr()); channel != nil {
if channel := a.GetChannelByAddr(srcAddr); channel != nil {
channelData := make([]byte, 4)
binary.BigEndian.PutUint16(channelData[0:], uint16(channel.ID))
binary.BigEndian.PutUint16(channelData[0:], uint16(channel.Number))
binary.BigEndian.PutUint16(channelData[2:], uint16(n))
channelData = append(channelData, buffer[:n]...)

Expand Down
233 changes: 213 additions & 20 deletions internal/allocation/allocation_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
package allocation

import (
"fmt"
"net"
"testing"
"time"

"github.com/gortc/turn"
"github.com/pion/stun"
"github.com/stretchr/testify/assert"

"github.com/pion/turn/internal/ipnet"
)

func TestAllocation(t *testing.T) {
Expand All @@ -13,6 +21,12 @@ func TestAllocation(t *testing.T) {
{"GetPermission", subTestGetPermission},
{"AddPermission", subTestAddPermission},
{"RemovePermission", subTestRemovePermission},
{"AddChannelBind", subTestAddChannelBind},
{"GetChannelByNumber", subTestGetChannelByNumber},
{"GetChannelByAddr", subTestGetChannelByAddr},
{"RemoveChannelBind", subTestRemoveChannelBind},
{"Close", subTestAllocationClose},
{"packetHandler", subTestPacketHandler},
}

for _, tc := range tt {
Expand Down Expand Up @@ -45,19 +59,13 @@ func subTestGetPermission(t *testing.T) {
a.AddPermission(p3)

foundP1 := a.GetPermission(addr)
if foundP1 != p {
t.Error("Should keep the first one.")
}
assert.Equal(t, p, foundP1, "Should keep the first one.")

foundP2 := a.GetPermission(addr2)
if foundP2 != p {
t.Error("Second one should be ignored.")
}
assert.Equal(t, p, foundP2, "Second one should be ignored.")

foundP3 := a.GetPermission(addr3)
if foundP3 != p3 {
t.Error("Permission with another IP should be found")
}
assert.Equal(t, p3, foundP3, "Permission with another IP should be found")
}

func subTestAddPermission(t *testing.T) {
Expand All @@ -69,14 +77,11 @@ func subTestAddPermission(t *testing.T) {
}

a.AddPermission(p)
if p.allocation != a {
t.Error("Permission's allocation should be the adder.")
}
assert.Equal(t, a, p.allocation, "Permission's allocation should be the adder.")

foundPermission := a.GetPermission(p.Addr)
if foundPermission != p {
t.Error("Got permission is not same as the the added.")
}
assert.Equal(t, p, foundPermission)

}

func subTestRemovePermission(t *testing.T) {
Expand All @@ -90,14 +95,202 @@ func subTestRemovePermission(t *testing.T) {
a.AddPermission(p)

foundPermission := a.GetPermission(p.Addr)
if foundPermission != p {
t.Error("Got permission is not same as the the added.")
}
assert.Equal(t, p, foundPermission, "Got permission is not same as the the added.")

a.RemovePermission(p.Addr)

foundPermission = a.GetPermission(p.Addr)
if foundPermission != nil {
t.Error("Got permission should be nil after removed.")
assert.Nil(t, foundPermission, "Got permission should be nil after removed.")
}

func subTestAddChannelBind(t *testing.T) {
a := NewAllocation(nil, nil, nil)

addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:3478")
c := NewChannelBind(turn.MinChannelNumber, addr, nil)

err := a.AddChannelBind(c, turn.DefaultLifetime)
assert.Nil(t, err, "should succeed")
assert.Equal(t, a, c.allocation, "allocation should be the caller.")

c2 := NewChannelBind(turn.MinChannelNumber+1, addr, nil)
err = a.AddChannelBind(c2, turn.DefaultLifetime)
assert.NotNil(t, err, "should failed with conflicted peer address")

addr2, _ := net.ResolveUDPAddr("udp", "127.0.0.1:3479")
c3 := NewChannelBind(turn.MinChannelNumber, addr2, nil)
err = a.AddChannelBind(c3, turn.DefaultLifetime)
assert.NotNil(t, err, "should fail with conflicted number.")
}

func subTestGetChannelByNumber(t *testing.T) {
a := NewAllocation(nil, nil, nil)

addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:3478")
c := NewChannelBind(turn.MinChannelNumber, addr, nil)

_ = a.AddChannelBind(c, turn.DefaultLifetime)

existChannel := a.GetChannelByNumber(c.Number)
assert.Equal(t, c, existChannel)

notExistChannel := a.GetChannelByNumber(turn.MinChannelNumber + 1)
assert.Nil(t, notExistChannel, "should be nil for not existed channel.")
}

func subTestGetChannelByAddr(t *testing.T) {
a := NewAllocation(nil, nil, nil)

addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:3478")
c := NewChannelBind(turn.MinChannelNumber, addr, nil)

_ = a.AddChannelBind(c, turn.DefaultLifetime)

existChannel := a.GetChannelByAddr(c.Peer)
assert.Equal(t, c, existChannel)

addr2, _ := net.ResolveUDPAddr("udp", "127.0.0.1:3479")
notExistChannel := a.GetChannelByAddr(addr2)
assert.Nil(t, notExistChannel, "should be nil for not existed channel.")
}

func subTestRemoveChannelBind(t *testing.T) {
a := NewAllocation(nil, nil, nil)

addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:3478")
c := NewChannelBind(turn.MinChannelNumber, addr, nil)

_ = a.AddChannelBind(c, turn.DefaultLifetime)

a.RemoveChannelBind(c.Number)

channelByNumber := a.GetChannelByNumber(c.Number)
assert.Nil(t, channelByNumber)

channelByAddr := a.GetChannelByAddr(c.Peer)
assert.Nil(t, channelByAddr)
}

func subTestAllocationClose(t *testing.T) {
network := "udp"

l, err := net.ListenPacket(network, "0.0.0.0:0")
if err != nil {
panic(err)
}

a := NewAllocation(nil, nil, nil)
a.RelaySocket = l
// add mock lifetimeTimer
a.lifetimeTimer = time.AfterFunc(turn.DefaultLifetime, func() {})

// add channel
addr, _ := net.ResolveUDPAddr(network, "127.0.0.1:3478")
c := NewChannelBind(turn.MinChannelNumber, addr, nil)
_ = a.AddChannelBind(c, turn.DefaultLifetime)

// add permission
a.AddPermission(NewPermission(addr, nil))

err = a.Close()
assert.Nil(t, err, "should succeed")
assert.True(t, isClose(a.RelaySocket), "should be closed")
}

func subTestPacketHandler(t *testing.T) {
network := "udp"

m := newTestManager()

// turn server initialization
turnSocket, err := net.ListenPacket(network, "127.0.0.1:0")
if err != nil {
panic(err)
}

// client listener initialization
clientListener, err := net.ListenPacket(network, "127.0.0.1:0")
if err != nil {
panic(err)
}

dataCh := make(chan []byte)
// client listener read data
go func() {
buffer := make([]byte, rtpMTU)
for {
n, _, err2 := clientListener.ReadFrom(buffer)
if err2 != nil {
return
}

dataCh <- buffer[:n]
}
}()

a, err := m.CreateAllocation(&FiveTuple{
SrcAddr: clientListener.LocalAddr(),
DstAddr: turnSocket.LocalAddr(),
}, turnSocket, 0, turn.DefaultLifetime)

assert.Nil(t, err, "should succeed")

peerListener1, err := net.ListenPacket(network, "127.0.0.1:0")
if err != nil {
panic(err)
}

peerListener2, err := net.ListenPacket(network, "127.0.0.1:0")
if err != nil {
panic(err)
}

// add permission with peer1 address
a.AddPermission(NewPermission(peerListener1.LocalAddr(), m.log))
// add channel with min channel number and peer2 address
channelBind := NewChannelBind(turn.MinChannelNumber, peerListener2.LocalAddr(), m.log)
_ = a.AddChannelBind(channelBind, turn.DefaultLifetime)

_, port, _ := ipnet.AddrIPPort(a.RelaySocket.LocalAddr())
relayAddrWithHostStr := fmt.Sprintf("127.0.0.1:%d", port)
relayAddrWithHost, _ := net.ResolveUDPAddr(network, relayAddrWithHostStr)

// test for permission and data message
targetText := "permission"
_, _ = peerListener1.WriteTo([]byte(targetText), relayAddrWithHost)
data := <-dataCh

// resolve stun data message
assert.True(t, stun.IsMessage(data), "should be stun message")

var msg stun.Message
err = stun.Decode(data, &msg)
assert.Nil(t, err, "decode data to stun message failed")

var msgData turn.Data
err = msgData.GetFrom(&msg)
assert.Nil(t, err, "get data from stun message failed")
assert.Equal(t, targetText, string(msgData), "get message doesn't equal the target text")

// test for channel bind and channel data
targetText2 := "channel bind"
_, _ = peerListener2.WriteTo([]byte(targetText2), relayAddrWithHost)
data = <-dataCh

// resolve channel data
assert.True(t, turn.IsChannelData(data), "should be channel data")

channelData := turn.ChannelData{
Raw: data,
}
err = channelData.Decode()
assert.Nil(t, err, fmt.Sprintf("channel data decode with error: %v", err))
assert.Equal(t, channelBind.Number, channelData.Number, "get channel data's number is invalid")
assert.Equal(t, targetText2, string(channelData.Data), "get data doesn't equal the target text.")

// listeners close
_ = m.Close()
_ = clientListener.Close()
_ = peerListener1.Close()
_ = peerListener2.Close()
}
18 changes: 9 additions & 9 deletions internal/allocation/channel_bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,33 @@ import (
// ChannelBind represents a TURN Channel
// https://tools.ietf.org/html/rfc5766#section-2.5
type ChannelBind struct {
Peer net.Addr
ID turn.ChannelNumber
Peer net.Addr
Number turn.ChannelNumber

allocation *Allocation
lifetimeTimer *time.Timer
log logging.LeveledLogger
}

// NewChannelBind creates a new ChannelBind
func NewChannelBind(id turn.ChannelNumber, peer net.Addr, log logging.LeveledLogger) *ChannelBind {
func NewChannelBind(number turn.ChannelNumber, peer net.Addr, log logging.LeveledLogger) *ChannelBind {
return &ChannelBind{
ID: id,
Peer: peer,
log: log,
Number: number,
Peer: peer,
log: log,
}
}

func (c *ChannelBind) start(lifetime time.Duration) {
c.lifetimeTimer = time.AfterFunc(lifetime, func() {
if !c.allocation.RemoveChannelBind(c.ID) {
c.log.Errorf("Failed to remove ChannelBind for %v %x %v", c.ID, c.Peer, c.allocation.fiveTuple)
if !c.allocation.RemoveChannelBind(c.Number) {
c.log.Errorf("Failed to remove ChannelBind for %v %x %v", c.Number, c.Peer, c.allocation.fiveTuple)
}
})
}

func (c *ChannelBind) refresh(lifetime time.Duration) {
if !c.lifetimeTimer.Reset(lifetime) {
c.log.Errorf("Failed to reset ChannelBind timer for %v %x %v", c.ID, c.Peer, c.allocation.fiveTuple)
c.log.Errorf("Failed to reset ChannelBind timer for %v %x %v", c.Number, c.Peer, c.allocation.fiveTuple)
}
}
Loading

0 comments on commit bb3bfbf

Please sign in to comment.