diff --git a/gather.go b/gather.go index bbe416d0..2c7b7489 100644 --- a/gather.go +++ b/gather.go @@ -236,7 +236,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { return errUDPMuxDisabled } - localIPs, err := localInterfaces(a.net, a.interfaceFilter, []NetworkType{NetworkTypeUDP4}) + localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.networkTypes) switch { case err != nil: return err @@ -254,7 +254,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { } } - conn, err := a.udpMux.GetConn(a.localUfrag) + conn, err := a.udpMux.GetConn(a.localUfrag, candidateIP.To4() == nil) if err != nil { return err } @@ -367,7 +367,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne return } - conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String()) + conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), networkType.IsIPv6()) if err != nil { a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v\n", network, url, err) return diff --git a/gather_test.go b/gather_test.go index 23c8296a..0e3ef243 100644 --- a/gather_test.go +++ b/gather_test.go @@ -557,7 +557,7 @@ func (m *universalUDPMuxMock) GetRelayedAddr(turnAddr net.Addr, deadline time.Du return nil, errNotImplemented } -func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string) (net.PacketConn, error) { +func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) { m.mu.Lock() defer m.mu.Unlock() m.getConnForURLTimes++ diff --git a/udp_mux.go b/udp_mux.go index 0bbe3b5c..8149d283 100644 --- a/udp_mux.go +++ b/udp_mux.go @@ -14,7 +14,7 @@ import ( // UDPMux allows multiple connections to go over a single UDP port type UDPMux interface { io.Closer - GetConn(ufrag string) (net.PacketConn, error) + GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error) RemoveConnByUfrag(ufrag string) } @@ -25,8 +25,8 @@ type UDPMuxDefault struct { closedChan chan struct{} closeOnce sync.Once - // conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType - conns map[string]*udpMuxedConn + // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType + connsIPv4, connsIPv6 map[string]*udpMuxedConn addressMapMu sync.RWMutex addressMap map[string]*udpMuxedConn @@ -54,7 +54,8 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { m := &UDPMuxDefault{ addressMap: map[string]*udpMuxedConn{}, params: params, - conns: make(map[string]*udpMuxedConn), + connsIPv4: make(map[string]*udpMuxedConn), + connsIPv6: make(map[string]*udpMuxedConn), closedChan: make(chan struct{}, 1), pool: &sync.Pool{ New: func() interface{} { @@ -76,7 +77,7 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr { // GetConn returns a PacketConn given the connection's ufrag and network // creates the connection if an existing one can't be found -func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) { +func (m *UDPMuxDefault) GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error) { m.mu.Lock() defer m.mu.Unlock() @@ -84,8 +85,8 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) { return nil, io.ErrClosedPipe } - if c, ok := m.conns[ufrag]; ok { - return c, nil + if conn, ok := m.getConn(ufrag, isIPv6); ok { + return conn, nil } c := m.createMuxedConn(ufrag) @@ -93,26 +94,30 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) { <-c.CloseChannel() m.removeConn(ufrag) }() - m.conns[ufrag] = c + + if isIPv6 { + m.connsIPv6[ufrag] = c + } else { + m.connsIPv4[ufrag] = c + } + return c, nil } // RemoveConnByUfrag stops and removes the muxed packet connection func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { - m.mu.Lock() - removedConns := make([]*udpMuxedConn, 0) - for key := range m.conns { - if key != ufrag { - continue - } + removedConns := make([]*udpMuxedConn, 0, 2) - c := m.conns[key] - delete(m.conns, key) - if c != nil { - removedConns = append(removedConns, c) - } + // Keep lock section small to avoid deadlock with conn lock + m.mu.Lock() + if c, ok := m.connsIPv4[ufrag]; ok { + delete(m.connsIPv4, ufrag) + removedConns = append(removedConns, c) + } + if c, ok := m.connsIPv6[ufrag]; ok { + delete(m.connsIPv6, ufrag) + removedConns = append(removedConns, c) } - // keep lock section small to avoid deadlock with conn lock m.mu.Unlock() m.addressMapMu.Lock() @@ -143,21 +148,39 @@ func (m *UDPMuxDefault) Close() error { m.mu.Lock() defer m.mu.Unlock() - for _, c := range m.conns { + for _, c := range m.connsIPv4 { _ = c.Close() } - m.conns = make(map[string]*udpMuxedConn) + for _, c := range m.connsIPv6 { + _ = c.Close() + } + + m.connsIPv4 = make(map[string]*udpMuxedConn) + m.connsIPv6 = make(map[string]*udpMuxedConn) + close(m.closedChan) }) return err } func (m *UDPMuxDefault) removeConn(key string) { - m.mu.Lock() - c := m.conns[key] - delete(m.conns, key) // keep lock section small to avoid deadlock with conn lock - m.mu.Unlock() + c := func() *udpMuxedConn { + m.mu.Lock() + defer m.mu.Unlock() + + if c, ok := m.connsIPv4[key]; ok { + delete(m.connsIPv4, key) + return c + } + + if c, ok := m.connsIPv6[key]; ok { + delete(m.connsIPv6, key) + return c + } + + return nil + }() if c == nil { return @@ -255,9 +278,10 @@ func (m *UDPMuxDefault) connWorker() { } ufrag := strings.Split(string(attr), ":")[0] + isIPv6 := udpAddr.IP.To4() == nil m.mu.Lock() - destinationConn = m.conns[ufrag] + destinationConn, _ = m.getConn(ufrag, isIPv6) m.mu.Unlock() } @@ -272,6 +296,15 @@ func (m *UDPMuxDefault) connWorker() { } } +func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { + if isIPv6 { + val, ok = m.connsIPv6[ufrag] + } else { + val, ok = m.connsIPv4[ufrag] + } + return +} + type bufferHolder struct { buffer []byte } diff --git a/udp_mux_test.go b/udp_mux_test.go index 3dd47f9a..127f25f9 100644 --- a/udp_mux_test.go +++ b/udp_mux_test.go @@ -65,7 +65,7 @@ func TestUDPMux(t *testing.T) { require.NoError(t, udpMux.Close()) // can't create more connections - _, err = udpMux.GetConn("failufrag") + _, err = udpMux.GetConn("failufrag", false) require.Error(t, err) } @@ -110,7 +110,7 @@ func TestAddressEncoding(t *testing.T) { } func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) { - pktConn, err := udpMux.GetConn(ufrag) + pktConn, err := udpMux.GetConn(ufrag, false) require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close() diff --git a/udp_mux_universal.go b/udp_mux_universal.go index c6198775..d6dc57da 100644 --- a/udp_mux_universal.go +++ b/udp_mux_universal.go @@ -16,7 +16,7 @@ type UniversalUDPMux interface { UDPMux GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) - GetConnForURL(ufrag string, url string) (net.PacketConn, error) + GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) } // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom. @@ -84,8 +84,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // and return a unique connection per server. -func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) { - return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url)) +func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) { + return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), isIPv6) } // ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address. diff --git a/udp_mux_universal_test.go b/udp_mux_universal_test.go index c263bf61..052eedbf 100644 --- a/udp_mux_universal_test.go +++ b/udp_mux_universal_test.go @@ -40,7 +40,7 @@ func TestUniversalUDPMux(t *testing.T) { } func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) { - pktConn, err := udpMux.GetConn(ufrag) + pktConn, err := udpMux.GetConn(ufrag, false) require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close()