From f63cfe77cd73b51a4210391664517a4dc1adc413 Mon Sep 17 00:00:00 2001 From: aggresss Date: Tue, 13 Feb 2024 20:45:20 +0800 Subject: [PATCH] Support TrackLocal RTX --- rtpcodec.go | 26 ++++++++++++++ rtpsender.go | 88 ++++++++++++++++++++++++++++++++++++++++++++---- sdp.go | 6 ++++ settingengine.go | 6 ++++ track_local.go | 5 +-- 5 files changed, 123 insertions(+), 8 deletions(-) diff --git a/rtpcodec.go b/rtpcodec.go index 40463dcb0ce..50a234214e2 100644 --- a/rtpcodec.go +++ b/rtpcodec.go @@ -4,6 +4,8 @@ package webrtc import ( + "fmt" + "regexp" "strings" "github.com/pion/webrtc/v4/internal/fmtp" @@ -123,3 +125,27 @@ func codecParametersFuzzySearch(needle RTPCodecParameters, haystack []RTPCodecPa return RTPCodecParameters{}, codecMatchNone } + +// Do a fuzzy find for a associated codec in the list of codecs +// Used for lookup up a associated codec in an existing list to find a match +// Returns codecMatchExact, codecMatchPartial, or codecMatchNone +func codecParametersAssociatedSearch(needle RTPCodecParameters, haystack []RTPCodecParameters) (RTPCodecParameters, codecMatchType) { + + // First attempt to match Exact + for _, c := range haystack { + if c.SDPFmtpLine == fmt.Sprintf("apt=%d", needle.PayloadType) { + return c, codecMatchExact + } + } + + // Fallback to just has apt codec + if re, err := regexp.Compile(`^apt=\d+$`); err == nil { + for _, c := range haystack { + if re.MatchString(c.SDPFmtpLine) { + return c, codecMatchPartial + } + } + } + + return RTPCodecParameters{}, codecMatchNone +} diff --git a/rtpsender.go b/rtpsender.go index 71be3fdb796..bf61015dc08 100644 --- a/rtpsender.go +++ b/rtpsender.go @@ -20,16 +20,18 @@ import ( ) type trackEncoding struct { - track TrackLocal - - srtpStream *srtpWriterFuture + track TrackLocal + context *baseTrackLocalContext + ssrc SSRC + srtpStream *srtpWriterFuture rtcpInterceptor interceptor.RTCPReader streamInfo interceptor.StreamInfo - context *baseTrackLocalContext - - ssrc SSRC + rtxSsrc SSRC + rtxSrtpStream *srtpWriterFuture + rtxRtcpInterceptor interceptor.RTCPReader + rtxStreamInfo interceptor.StreamInfo } // RTPSender allows an application to control how a given Track is encoded and transmitted to a remote peer @@ -125,6 +127,7 @@ func (r *RTPSender) getParameters() RTPSendParameters { RID: rid, SSRC: trackEncoding.ssrc, PayloadType: r.payloadType, + RTX: RTPRtxParameters{SSRC: trackEncoding.rtxSsrc}, }, }) } @@ -204,6 +207,16 @@ func (r *RTPSender) addEncoding(track TrackLocal) { ssrc: SSRC(randutil.NewMathRandomGenerator().Uint32()), } + if r.api.settingEngine.enableTrackLocalRtx { + codecs := r.api.mediaEngine.getCodecsByKind(track.Kind()) + for _, c := range codecs { + if _, matchType := codecParametersAssociatedSearch(c, codecs); matchType != codecMatchNone { + trackEncoding.rtxSsrc = SSRC(randutil.NewMathRandomGenerator().Uint32()) + break + } + } + } + r.trackEncodings = append(r.trackEncodings, trackEncoding) } @@ -339,6 +352,39 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error { ) writeStream.interceptor.Store(rtpInterceptor) + + if rtxCodec, matchType := codecParametersAssociatedSearch(codec, r.api.mediaEngine.getCodecsByKind(r.kind)); matchType == codecMatchExact && + parameters.Encodings[idx].RTX.SSRC != 0 { + + rtxSrtpStream := &srtpWriterFuture{ssrc: parameters.Encodings[idx].RTX.SSRC, rtpSender: r} + + trackEncoding.rtxSrtpStream = rtxSrtpStream + trackEncoding.rtxSsrc = parameters.Encodings[idx].RTX.SSRC + + trackEncoding.rtxStreamInfo = *createStreamInfo( + r.id+"_rtx", + parameters.Encodings[idx].RTX.SSRC, + rtxCodec.PayloadType, + rtxCodec.RTPCodecCapability, + parameters.HeaderExtensions, + ) + trackEncoding.rtxStreamInfo.Attributes.Set("apt_ssrc", uint32(parameters.Encodings[idx].SSRC)) + + trackEncoding.rtxRtcpInterceptor = r.api.interceptor.BindRTCPReader( + interceptor.RTCPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { + n, err = trackEncoding.rtxSrtpStream.Read(in) + return n, a, err + }), + ) + + r.api.interceptor.BindLocalStream( + &trackEncoding.rtxStreamInfo, + interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + return srtpStream.WriteRTP(header, payload) + }), + ) + + } } close(r.sendCalled) @@ -402,6 +448,36 @@ func (r *RTPSender) ReadRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { return pkts, attributes, nil } +// ReadRtx reads incoming RTX Stream RTCP for this RTPSender +func (r *RTPSender) ReadRtx(b []byte) (n int, a interceptor.Attributes, err error) { + if r.trackEncodings[0].rtxRtcpInterceptor == nil { + return 0, nil, io.ErrNoProgress + } + + select { + case <-r.sendCalled: + return r.trackEncodings[0].rtxRtcpInterceptor.Read(b, a) + case <-r.stopCalled: + return 0, nil, io.ErrClosedPipe + } +} + +// ReadRtxRTCP is a convenience method that wraps ReadRtx and unmarshals for you. +func (r *RTPSender) ReadRtxRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { + b := make([]byte, r.api.settingEngine.getReceiveMTU()) + i, attributes, err := r.ReadRtx(b) + if err != nil { + return nil, nil, err + } + + pkts, err := rtcp.Unmarshal(b[:i]) + if err != nil { + return nil, nil, err + } + + return pkts, attributes, nil +} + // ReadSimulcast reads incoming RTCP for this RTPSender for given rid func (r *RTPSender) ReadSimulcast(b []byte, rid string) (n int, a interceptor.Attributes, err error) { select { diff --git a/sdp.go b/sdp.go index 49783d91aca..853338b30fd 100644 --- a/sdp.go +++ b/sdp.go @@ -389,7 +389,13 @@ func addSenderSDP( sendParameters := sender.GetParameters() for _, encoding := range sendParameters.Encodings { + if encoding.RTX.SSRC != 0 { + media = media.WithValueAttribute(sdp.AttrKeySSRCGroup, fmt.Sprintf("FID %d %d", encoding.SSRC, encoding.RTX.SSRC)) + } media = media.WithMediaSource(uint32(encoding.SSRC), track.StreamID() /* cname */, track.StreamID() /* streamLabel */, track.ID()) + if encoding.RTX.SSRC != 0 { + media = media.WithMediaSource(uint32(encoding.RTX.SSRC), track.StreamID() /* cname */, track.StreamID() /* streamLabel */, track.ID()) + } if !isPlanB { media = media.WithPropertyAttribute("msid:" + track.StreamID() + " " + track.ID()) } diff --git a/settingengine.go b/settingengine.go index ddf679a3535..cb4f8bda5b4 100644 --- a/settingengine.go +++ b/settingengine.go @@ -92,6 +92,7 @@ type SettingEngine struct { srtpProtectionProfiles []dtls.SRTPProtectionProfile receiveMTU uint iceMaxBindingRequests *uint16 + enableTrackLocalRtx bool } // getReceiveMTU returns the configured MTU. If SettingEngine's MTU is configured to 0 it returns the default @@ -437,3 +438,8 @@ func (e *SettingEngine) SetSCTPMaxReceiveBufferSize(maxReceiveBufferSize uint32) func (e *SettingEngine) SetDTLSCustomerCipherSuites(customCipherSuites func() []dtls.CipherSuite) { e.dtls.customCipherSuites = customCipherSuites } + +// SetEnableTrackLocalRtx allows track local use RTX. +func (e *SettingEngine) SetEnableTrackLocalRtx(enable bool) { + e.enableTrackLocalRtx = enable +} diff --git a/track_local.go b/track_local.go index 21131c81119..930d53cff86 100644 --- a/track_local.go +++ b/track_local.go @@ -44,8 +44,9 @@ type TrackLocalContext interface { } type baseTrackLocalContext struct { - id string - params RTPParameters + id string + params RTPParameters + ssrc SSRC writeStream TrackLocalWriter rtcpInterceptor interceptor.RTCPReader