From a39284c11284dc3e04764f6a4f1340ed2046eccc Mon Sep 17 00:00:00 2001 From: boks1971 Date: Wed, 15 Nov 2023 09:47:37 +0530 Subject: [PATCH] Use atomic to avoid stale SRTP protection profile `state` is acccessed without lock in the FSM. In some cases, that leads to stale values. For example, `srtpProtectionProfile` is set in flight handlers (differnt flight handlers in client and server). But, when it is accessed via the API `SelectedSRTPProtectionProfile`, it gets a stale value as it appears that the two goroutines are out-of-sync on that piece of shared memory. This is a larger concern for use of `state`. Ideally, either - `state` should have a lock internally and all fields are accessed through methods. - carefully split fields of `state` to ensure process access/sync. Doing the smaller change here to address one field that has seen stale value. --- conn.go | 8 +++----- flight0handler.go | 2 +- flight3handler.go | 4 ++-- flight4bhandler.go | 4 ++-- flight4handler.go | 4 ++-- state.go | 18 +++++++++++++++--- 6 files changed, 25 insertions(+), 15 deletions(-) diff --git a/conn.go b/conn.go index 9d1da84cb..b0d8cde4f 100644 --- a/conn.go +++ b/conn.go @@ -394,14 +394,12 @@ func (c *Conn) ConnectionState() State { // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) { - c.lock.RLock() - defer c.lock.RUnlock() - - if c.state.srtpProtectionProfile == 0 { + profile := c.state.getSRTPProtectionProfile() + if profile == 0 { return 0, false } - return c.state.srtpProtectionProfile, true + return profile, true } func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error { diff --git a/flight0handler.go b/flight0handler.go index 648c52883..0a45c58d4 100644 --- a/flight0handler.go +++ b/flight0handler.go @@ -66,7 +66,7 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak if !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile } - state.srtpProtectionProfile = profile + state.setSRTPProtectionProfile(profile) case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true diff --git a/flight3handler.go b/flight3handler.go index 920ee73bd..90dc1a6e3 100644 --- a/flight3handler.go +++ b/flight3handler.go @@ -56,7 +56,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh if !found { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile } - state.srtpProtectionProfile = profile + state.setSRTPProtectionProfile(profile) case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true @@ -83,7 +83,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS } - if len(cfg.localSRTPProtectionProfiles) > 0 && state.srtpProtectionProfile == 0 { + if len(cfg.localSRTPProtectionProfiles) > 0 && state.getSRTPProtectionProfile() == 0 { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension } diff --git a/flight4bhandler.go b/flight4bhandler.go index 6bbbc5972..6b1b90469 100644 --- a/flight4bhandler.go +++ b/flight4bhandler.go @@ -59,9 +59,9 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha Supported: true, }) } - if state.srtpProtectionProfile != 0 { + if state.getSRTPProtectionProfile() != 0 { extensions = append(extensions, &extension.UseSRTP{ - ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile}, + ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, }) } diff --git a/flight4handler.go b/flight4handler.go index 52568139f..cd8f2884a 100644 --- a/flight4handler.go +++ b/flight4handler.go @@ -228,9 +228,9 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha Supported: true, }) } - if state.srtpProtectionProfile != 0 { + if state.getSRTPProtectionProfile() != 0 { extensions = append(extensions, &extension.UseSRTP{ - ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile}, + ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, }) } if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { diff --git a/state.go b/state.go index a65d426bc..b04045ac9 100644 --- a/state.go +++ b/state.go @@ -24,7 +24,7 @@ type State struct { cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen CipherSuiteID CipherSuiteID - srtpProtectionProfile SRTPProtectionProfile // Negotiated SRTPProtectionProfile + srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile PeerCertificates [][]byte IdentityHint []byte SessionID []byte @@ -106,7 +106,7 @@ func (s *State) serialize() *serializedState { SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]), LocalRandom: localRnd, RemoteRandom: remoteRnd, - SRTPProtectionProfile: uint16(s.srtpProtectionProfile), + SRTPProtectionProfile: uint16(s.getSRTPProtectionProfile()), PeerCertificates: s.PeerCertificates, IdentityHint: s.IdentityHint, SessionID: s.SessionID, @@ -145,7 +145,7 @@ func (s *State) deserialize(serialized serializedState) { s.cipherSuite = cipherSuiteForID(s.CipherSuiteID, nil) atomic.StoreUint64(&s.localSequenceNumber[epoch], serialized.SequenceNumber) - s.srtpProtectionProfile = SRTPProtectionProfile(serialized.SRTPProtectionProfile) + s.setSRTPProtectionProfile(SRTPProtectionProfile(serialized.SRTPProtectionProfile)) // Set remote certificate s.PeerCertificates = serialized.PeerCertificates @@ -242,3 +242,15 @@ func (s *State) getLocalEpoch() uint16 { } return 0 } + +func (s *State) setSRTPProtectionProfile(profile SRTPProtectionProfile) { + s.srtpProtectionProfile.Store(profile) +} + +func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile { + if val, ok := s.srtpProtectionProfile.Load().(SRTPProtectionProfile); ok { + return val + } + + return 0 +}