From 555670f80efb1e102dd6294a732894c5f64ca3e6 Mon Sep 17 00:00:00 2001 From: Sean DuBois Date: Wed, 10 Jul 2024 23:08:15 -0400 Subject: [PATCH] On Read Retransmit send FSM to SENDING RFC6347 Section-4.2.4 states ``` The implementation reads a retransmitted flight from the peer: the implementation transitions to the SENDING state, where it retransmits the flight, resets the retransmit timer, and returns to the WAITING state. The rationale here is that the receipt of a duplicate message is the likely result of timer expiry on the peer and therefore suggests that part of one's previous flight was lost. ``` Resolves #478 --- conn.go | 25 ++++++++++++++++++------- fragment_buffer.go | 17 +++++++++++------ fragment_buffer_test.go | 6 +++--- handshaker.go | 5 ++++- 4 files changed, 36 insertions(+), 17 deletions(-) diff --git a/conn.go b/conn.go index d82228f31..bc54eaae0 100644 --- a/conn.go +++ b/conn.go @@ -878,7 +878,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A } } - isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...)) + isHandshake, isRetransmit, err := c.fragmentBuffer.push(append([]byte{}, buf...)) if err != nil { // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] @@ -886,13 +886,24 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A return false, nil, nil } else if isHandshake { markPacketAsValid() - for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() { - header := &handshake.Header{} - if err := header.Unmarshal(out); err != nil { - c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err) - continue + + if isRetransmit { + // The implementation reads a retransmitted flight from the peer: the + // implementation transitions to the SENDING state + // [RFC6347 Section-4.2.4] + select { + case c.fsm.readRetransmit <- struct{}{}: + default: + } + } else { + for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() { + header := &handshake.Header{} + if err := header.Unmarshal(out); err != nil { + c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err) + continue + } + c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient) } - c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient) } return true, nil, nil diff --git a/fragment_buffer.go b/fragment_buffer.go index fb5af6c3d..50506efc6 100644 --- a/fragment_buffer.go +++ b/fragment_buffer.go @@ -43,24 +43,29 @@ func (f *fragmentBuffer) size() int { // Attempts to push a DTLS packet to the fragmentBuffer // when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled // when an error returns it is fatal, and the DTLS connection should be stopped -func (f *fragmentBuffer) push(buf []byte) (bool, error) { +func (f *fragmentBuffer) push(buf []byte) (isHandshake, isRetransmit bool, err error) { if f.size()+len(buf) >= fragmentBufferMaxSize { - return false, errFragmentBufferOverflow + return false, false, errFragmentBufferOverflow } frag := new(fragment) if err := frag.recordLayerHeader.Unmarshal(buf); err != nil { - return false, err + return false, false, err } // fragment isn't a handshake, we don't need to handle it if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake { - return false, nil + return false, false, nil } for buf = buf[recordlayer.FixedHeaderSize:]; len(buf) != 0; frag = new(fragment) { if err := frag.handshakeHeader.Unmarshal(buf); err != nil { - return false, err + return false, false, err + } + + // Fragment is a retransmission, we have already assembled it so ignoring + if frag.handshakeHeader.MessageSequence < f.currentMessageSequenceNumber { + return true, true, nil } if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok { @@ -80,7 +85,7 @@ func (f *fragmentBuffer) push(buf []byte) (bool, error) { buf = buf[end:] } - return true, nil + return true, false, nil } func (f *fragmentBuffer) pop() (content []byte, epoch uint16) { diff --git a/fragment_buffer_test.go b/fragment_buffer_test.go index ad8834e71..2b2f62c7e 100644 --- a/fragment_buffer_test.go +++ b/fragment_buffer_test.go @@ -94,7 +94,7 @@ func TestFragmentBuffer(t *testing.T) { } { fragmentBuffer := newFragmentBuffer() for _, frag := range test.In { - status, err := fragmentBuffer.push(frag) + status, _, err := fragmentBuffer.push(frag) if err != nil { t.Error(err) } else if !status { @@ -122,13 +122,13 @@ func TestFragmentBuffer_Overflow(t *testing.T) { fragmentBuffer := newFragmentBuffer() // Push a buffer that doesn't exceed size limits - if _, err := fragmentBuffer.push([]byte{0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}); err != nil { + if _, _, err := fragmentBuffer.push([]byte{0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}); err != nil { t.Fatal(err) } // Allocate a buffer that exceeds cache size largeBuffer := make([]byte, fragmentBufferMaxSize) - if _, err := fragmentBuffer.push(largeBuffer); !errors.Is(err, errFragmentBufferOverflow) { + if _, _, err := fragmentBuffer.push(largeBuffer); !errors.Is(err, errFragmentBufferOverflow) { t.Fatalf("Pushing a large buffer returned (%s) expected(%s)", err, errFragmentBufferOverflow) } } diff --git a/handshaker.go b/handshaker.go index 4e0f1ad95..00756babe 100644 --- a/handshaker.go +++ b/handshaker.go @@ -90,6 +90,7 @@ type handshakeFSM struct { cache *handshakeCache cfg *handshakeConfig closed chan struct{} + readRetransmit chan struct{} } type handshakeConfig struct { @@ -173,6 +174,7 @@ func newHandshakeFSM( cfg: cfg, retransmitInterval: cfg.initialRetransmitInterval, closed: make(chan struct{}), + readRetransmit: make(chan struct{}, 1), } } @@ -303,7 +305,6 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, } s.currentFlight = nextFlight return handshakePreparing, nil - case <-retransmitTimer.C: if !s.retransmit { return handshakeWaiting, nil @@ -319,6 +320,8 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, s.retransmitInterval = time.Second * 60 } return handshakeSending, nil + case <-s.readRetransmit: + return handshakeSending, nil case <-ctx.Done(): s.retransmitInterval = s.cfg.initialRetransmitInterval return handshakeErrored, ctx.Err()