From 7a3880b814408fbd12d87cb3d41780064d884e70 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Tue, 18 Jun 2024 19:00:20 -0400 Subject: [PATCH] Fire OnTrack before reading first RTP Prior to this, we would wait for a single RTP packet to figure out the codec which is not to spec. --- .gitignore | 1 + peerconnection.go | 16 +---------- peerconnection_go_test.go | 45 +++++++++++------------------- peerconnection_media_test.go | 53 ++++++++++++++++++++++++++---------- track_local_static_test.go | 4 +++ track_remote.go | 41 +--------------------------- 6 files changed, 61 insertions(+), 99 deletions(-) diff --git a/.gitignore b/.gitignore index 6e2f206a9f6..b7f3da3e91c 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ cover.out examples/sfu-ws/cert.pem examples/sfu-ws/key.pem wasm_exec.js +*.DS_Store diff --git a/peerconnection.go b/peerconnection.go index ef370664594..1c5db72c81e 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -1236,21 +1236,7 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece return } - go func(track *TrackRemote) { - b := make([]byte, pc.api.settingEngine.getReceiveMTU()) - n, _, err := track.peek(b) - if err != nil { - pc.log.Warnf("Could not determine PayloadType for SSRC %d (%s)", track.SSRC(), err) - return - } - - if err = track.checkAndUpdateTrack(b[:n]); err != nil { - pc.log.Warnf("Failed to set codec settings for track SSRC %d (%s)", track.SSRC(), err) - return - } - - pc.onTrack(track, receiver) - }(t) + pc.onTrack(t, receiver) } } diff --git a/peerconnection_go_test.go b/peerconnection_go_test.go index 2dde30694a9..a7141681877 100644 --- a/peerconnection_go_test.go +++ b/peerconnection_go_test.go @@ -23,7 +23,6 @@ import ( "time" "github.com/pion/ice/v3" - "github.com/pion/rtp" "github.com/pion/transport/v3/test" "github.com/pion/transport/v3/vnet" "github.com/pion/webrtc/v4/internal/util" @@ -1000,9 +999,11 @@ func TestICERestart_Error_Handling(t *testing.T) { } type trackRecords struct { - mu sync.Mutex - trackIDs map[string]struct{} - receivedTrackIDs map[string]struct{} + mu sync.Mutex + trackIDs map[string]struct{} + receivedTrackIDs map[string]struct{} + onAllTracksReceived chan struct{} + onAllTracksReceivedOnce sync.Once } func (r *trackRecords) newTrack() (*TrackLocalStaticRTP, error) { @@ -1019,6 +1020,11 @@ func (r *trackRecords) handleTrack(t *TrackRemote, _ *RTPReceiver) { if _, exist := r.trackIDs[tID]; exist { r.receivedTrackIDs[tID] = struct{}{} } + if len(r.receivedTrackIDs) == len(r.trackIDs) { + r.onAllTracksReceivedOnce.Do(func() { + close(r.onAllTracksReceived) + }) + } } func (r *trackRecords) remains() int { @@ -1032,32 +1038,17 @@ func TestPeerConnection_MassiveTracks(t *testing.T) { var ( api = NewAPI() tRecs = &trackRecords{ - trackIDs: make(map[string]struct{}), - receivedTrackIDs: make(map[string]struct{}), + trackIDs: make(map[string]struct{}), + receivedTrackIDs: make(map[string]struct{}), + onAllTracksReceived: make(chan struct{}), } tracks = []*TrackLocalStaticRTP{} trackCount = 256 pingInterval = 1 * time.Second noiseInterval = 100 * time.Microsecond timeoutDuration = 20 * time.Second - rawPkt = []byte{ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, - 0x27, 0x82, 0x00, 0x01, 0x00, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0x98, 0x36, 0xbe, 0x88, 0x9e, - } - samplePkt = &rtp.Packet{ - Header: rtp.Header{ - Marker: true, - Extension: false, - ExtensionProfile: 1, - Version: 2, - SequenceNumber: 27023, - Timestamp: 3653407706, - CSRC: []uint32{}, - }, - Payload: rawPkt[20:], - } - connected = make(chan struct{}) - stopped = make(chan struct{}) + connected = make(chan struct{}) + stopped = make(chan struct{}) ) assert.NoError(t, api.mediaEngine.RegisterDefaultCodecs()) offerPC, answerPC, err := api.newPair(Configuration{}) @@ -1090,12 +1081,8 @@ func TestPeerConnection_MassiveTracks(t *testing.T) { } }() assert.NoError(t, signalPair(offerPC, answerPC)) - // Send a RTP packets to each track to trigger track event after connected. <-connected - time.Sleep(1 * time.Second) - for _, track := range tracks { - assert.NoError(t, track.WriteRTP(samplePkt)) - } + // Ping trackRecords to see if any track event not received yet. tooLong := time.After(timeoutDuration) for { diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index 1a56dfb4bb7..9a968c1a70a 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -20,6 +20,8 @@ import ( "testing" "time" + "github.com/pion/interceptor" + mock_interceptor "github.com/pion/interceptor/pkg/mock" "github.com/pion/logging" "github.com/pion/randutil" "github.com/pion/rtcp" @@ -1045,11 +1047,33 @@ func TestPeerConnection_Simulcast_Probe(t *testing.T) { m := &MediaEngine{} assert.NoError(t, m.RegisterDefaultCodecs()) + ir := &interceptor.Registry{} + + trackReadDone := make(chan struct{}) + ir.Add(&mock_interceptor.Factory{ + NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { + return &mock_interceptor.Interceptor{ + BindRemoteStreamFn: func(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + count := int64(0) + return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + if a == nil { + a = interceptor.Attributes{} + } + if atomic.AddInt64(&count, 1) > 5 { + // confirm read before sending any more packets for probing + <-trackReadDone + } + return reader.Read(b, a) + }) + }, + }, nil + }, + }) assert.NoError(t, ConfigureSimulcastExtensionHeaders(m)) pcOffer, pcAnswer, err := NewAPI(WithSettingEngine(SettingEngine{ LoggerFactory: &undeclaredSsrcLoggerFactory{unhandledSimulcastError}, - }), WithMediaEngine(m)).newPair(Configuration{}) + }), WithMediaEngine(m), WithInterceptorRegistry(ir)).newPair(Configuration{}) assert.NoError(t, err) firstTrack, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "firstTrack", "firstTrack") @@ -1093,26 +1117,24 @@ func TestPeerConnection_Simulcast_Probe(t *testing.T) { time.Sleep(20 * time.Millisecond) } + // establish undeclared SSRC (half number of probes) for ; sequenceNumber <= 5; sequenceNumber++ { sendRTPPacket() } - assert.NoError(t, signalPair(pcOffer, pcAnswer)) - trackRemoteChan := make(chan *TrackRemote, 1) - pcAnswer.OnTrack(func(trackRemote *TrackRemote, _ *RTPReceiver) { + pcAnswer.OnTrack(func(trackRemote *TrackRemote, recv *RTPReceiver) { trackRemoteChan <- trackRemote }) - trackRemote := func() *TrackRemote { - for { - select { - case t := <-trackRemoteChan: - return t - default: - sendRTPPacket() - } - } + assert.NoError(t, signalPair(pcOffer, pcAnswer)) + + trackRemote := <-trackRemoteChan + + go func() { + _, _, err = trackRemote.Read(make([]byte, 1500)) + assert.NoError(t, err) + close(trackReadDone) }() func() { @@ -1126,8 +1148,7 @@ func TestPeerConnection_Simulcast_Probe(t *testing.T) { } }() - _, _, err = trackRemote.Read(make([]byte, 1500)) - assert.NoError(t, err) + <-trackReadDone closePairNow(t, pcOffer, pcAnswer) }) @@ -1755,6 +1776,8 @@ func TestPeerConnection_Zero_PayloadType(t *testing.T) { trackFired := make(chan struct{}) pcAnswer.OnTrack(func(track *TrackRemote, _ *RTPReceiver) { + _, _, err = track.Read(make([]byte, 1500)) + assert.NoError(t, err) require.Equal(t, track.Codec().MimeType, MimeTypePCMU) close(trackFired) }) diff --git a/track_local_static_test.go b/track_local_static_test.go index f4d9dfd92fd..25ca285620d 100644 --- a/track_local_static_test.go +++ b/track_local_static_test.go @@ -150,6 +150,8 @@ func Test_TrackLocalStatic_PayloadType(t *testing.T) { onTrackFired, onTrackFiredFunc := context.WithCancel(context.Background()) offerer.OnTrack(func(track *TrackRemote, _ *RTPReceiver) { + _, _, err = track.Read(make([]byte, 1500)) + assert.NoError(t, err) assert.Equal(t, track.PayloadType(), PayloadType(100)) assert.Equal(t, track.Codec().RTPCodecCapability.MimeType, "video/VP8") @@ -284,6 +286,8 @@ func Test_TrackLocalStatic_Padding(t *testing.T) { onTrackFired, onTrackFiredFunc := context.WithCancel(context.Background()) offerer.OnTrack(func(track *TrackRemote, _ *RTPReceiver) { + _, _, err = track.Read(make([]byte, 1500)) + assert.NoError(t, err) assert.Equal(t, track.PayloadType(), PayloadType(100)) assert.Equal(t, track.Codec().RTPCodecCapability.MimeType, "video/VP8") diff --git a/track_remote.go b/track_remote.go index 7e448dd9895..2fd7f5f5eca 100644 --- a/track_remote.go +++ b/track_remote.go @@ -29,9 +29,7 @@ type TrackRemote struct { params RTPParameters rid string - receiver *RTPReceiver - peeked []byte - peekedAttributes interceptor.Attributes + receiver *RTPReceiver } func newTrackRemote(kind RTPCodecType, ssrc, rtxSsrc SSRC, rid string, receiver *RTPReceiver) *TrackRemote { @@ -107,26 +105,8 @@ func (t *TrackRemote) Codec() RTPCodecParameters { func (t *TrackRemote) Read(b []byte) (n int, attributes interceptor.Attributes, err error) { t.mu.RLock() r := t.receiver - peeked := t.peeked != nil t.mu.RUnlock() - if peeked { - t.mu.Lock() - data := t.peeked - attributes = t.peekedAttributes - - t.peeked = nil - t.peekedAttributes = nil - t.mu.Unlock() - // someone else may have stolen our packet when we - // released the lock. Deal with it. - if data != nil { - n = copy(b, data) - err = t.checkAndUpdateTrack(b) - return - } - } - // If there's a separate RTX track and an RTX packet is available, return that if rtxPacketReceived := r.readRTX(t); rtxPacketReceived != nil { n = copy(b, rtxPacketReceived.pkt) @@ -187,25 +167,6 @@ func (t *TrackRemote) ReadRTP() (*rtp.Packet, interceptor.Attributes, error) { return r, attributes, nil } -// peek is like Read, but it doesn't discard the packet read -func (t *TrackRemote) peek(b []byte) (n int, a interceptor.Attributes, err error) { - n, a, err = t.Read(b) - if err != nil { - return - } - - t.mu.Lock() - // this might overwrite data if somebody peeked between the Read - // and us getting the lock. Oh well, we'll just drop a packet in - // that case. - data := make([]byte, n) - n = copy(data, b[:n]) - t.peeked = data - t.peekedAttributes = a - t.mu.Unlock() - return -} - // SetReadDeadline sets the max amount of time the RTP stream will block before returning. 0 is forever. func (t *TrackRemote) SetReadDeadline(deadline time.Time) error { return t.receiver.setRTPReadDeadline(deadline, t)