diff --git a/interceptor_test.go b/interceptor_test.go index aff8b67573c..0dc0eca4b24 100644 --- a/interceptor_test.go +++ b/interceptor_test.go @@ -9,12 +9,14 @@ package webrtc // import ( "context" + "io" "sync/atomic" "testing" "time" "github.com/pion/interceptor" mock_interceptor "github.com/pion/interceptor/pkg/mock" + "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/pion/transport/v3/test" "github.com/pion/webrtc/v4/pkg/media" @@ -284,3 +286,146 @@ func Test_Interceptor_ZeroSSRC(t *testing.T) { <-probeReceiverCreated closePairNow(t, offerer, answerer) } + +// TestInterceptorNack is an end-to-end test for the NACK sender. +// It test that: +// - we get a NACK if we negotiated generic NACks; +// - we don't get a NACK if we did not negotiate generick NACKs; +// - the NACK corresponds to the missing packet. +func TestInterceptorNack(t *testing.T) { + const numPackets = 20 + to := test.TimeOut(time.Second * 20) + defer to.Stop() + + t.Run("Nack", func(t *testing.T) { testInterceptorNack(t, true) }) + t.Run("NoNack", func(t *testing.T) { testInterceptorNack(t, false) }) +} + +func testInterceptorNack(t *testing.T, requestNack bool) { + ir := interceptor.Registry{} + m := MediaEngine{} + var capability []RTCPFeedback + if requestNack { + capability = append(capability, RTCPFeedback{"nack", ""}) + } + err := m.RegisterCodec( + RTPCodecParameters{ + RTPCodecCapability: RTPCodecCapability{ + "video/VP8", 90000, 0, + "", + capability, + }, + PayloadType: 96, + }, + RTPCodecTypeVideo, + ) + assert.NoError(t, err) + api := NewAPI( + WithMediaEngine(&m), + WithInterceptorRegistry(&ir), + ) + + pc1, err := api.NewPeerConnection(Configuration{}) + assert.NoError(t, err) + defer pc1.Close() + + track1, err := NewTrackLocalStaticRTP( + RTPCodecCapability{MimeType: MimeTypeVP8}, + "video", "pion", + ) + assert.NoError(t, err) + sender, err := pc1.AddTrack(track1) + assert.NoError(t, err) + + pc2, err := NewPeerConnection(Configuration{}) + assert.NoError(t, err) + defer pc2.Close() + + offer, err := pc1.CreateOffer(nil) + assert.NoError(t, err) + err = pc1.SetLocalDescription(offer) + assert.NoError(t, err) + <-GatheringCompletePromise(pc1) + + err = pc2.SetRemoteDescription(*pc1.LocalDescription()) + assert.NoError(t, err) + answer, err := pc2.CreateAnswer(nil) + assert.NoError(t, err) + err = pc2.SetLocalDescription(answer) + assert.NoError(t, err) + <-GatheringCompletePromise(pc2) + + err = pc1.SetRemoteDescription(*pc2.LocalDescription()) + assert.NoError(t, err) + + gotNack := false + go func() { + buf := make([]byte, 1500) + for { + n, _, err := sender.Read(buf) + if err == io.EOF { + break + } + assert.NoError(t, err) + ps, err := rtcp.Unmarshal(buf[:n]) + assert.NoError(t, err) + for _, p := range ps { + if pn, ok := p.(*rtcp.TransportLayerNack); ok { + assert.Equal(t, len(pn.Nacks), 1) + assert.Equal(t, + pn.Nacks[0].PacketID, uint16(1), + ) + assert.Equal(t, + pn.Nacks[0].LostPackets, + rtcp.PacketBitmap(0), + ) + gotNack = true + } + } + } + }() + + const numPackets = 20 + done := make(chan struct{}) + pc2.OnTrack(func(track2 *TrackRemote, receiver *RTPReceiver) { + for i := 0; i < numPackets; i++ { + if i == 1 { + continue + } + p, _, err := track2.ReadRTP() + assert.NoError(t, err) + assert.Equal(t, p.SequenceNumber, uint16(i)) + } + done <- struct{}{} + }) + + go func() { + for i := 0; i < numPackets; i++ { + time.Sleep(20 * time.Millisecond) + if i == 1 { + continue + } + var p rtp.Packet + p.Version = 2 + p.Marker = true + p.PayloadType = 96 + p.SequenceNumber = uint16(i) + p.Timestamp = uint32(i * 90000 / 50) + p.Payload = []byte{42} + err := track1.WriteRTP(&p) + assert.NoError(t, err) + } + }() + + <-done + + if requestNack { + if !gotNack { + t.Errorf("Expected to get a NACK, got none") + } + } else { + if gotNack { + t.Errorf("Expected to get no NACK, got one") + } + } +}