From b215436dbb6de5d3366d67cbc2c73302d700e1b4 Mon Sep 17 00:00:00 2001 From: Maxime Piraux Date: Sat, 23 Mar 2019 17:11:51 +0100 Subject: [PATCH] Adds a spin bit test --- connection.go | 5 +++- headers.go | 10 +++++++ scenarii/scenario.go | 1 + scenarii/spin_bit.go | 62 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 scenarii/spin_bit.go diff --git a/connection.go b/connection.go index ad1e5d9..fd2b08c 100644 --- a/connection.go +++ b/connection.go @@ -23,7 +23,10 @@ type Connection struct { Tls *pigotls.Connection TLSTPHandler *TLSTransportParameterHandler - KeyPhaseIndex uint + + KeyPhaseIndex uint + SpinBit SpinBit + LastSpinNumber PacketNumber CryptoStates map[EncryptionLevel]*CryptoState diff --git a/headers.go b/headers.go index 9e3e9fe..0d7e416 100644 --- a/headers.go +++ b/headers.go @@ -149,6 +149,7 @@ func (t PacketType) PNSpace() PNSpace { } type ShortHeader struct { + SpinBit SpinBit KeyPhase KeyPhaseBit DestinationCID ConnectionID truncatedPN TruncatedPN @@ -158,6 +159,9 @@ func (h *ShortHeader) Encode() []byte { buffer := new(bytes.Buffer) var typeByte uint8 typeByte |= 0x40 + if h.SpinBit == SpinValueOne { + typeByte |= 0x20 + } if h.KeyPhase == KeyPhaseOne { typeByte |= 0x04 } @@ -177,6 +181,7 @@ func (h *ShortHeader) HeaderLength() int { return 1 + len(h. func ReadShortHeader(buffer *bytes.Reader, conn *Connection) *ShortHeader { h := new(ShortHeader) typeByte, _ := buffer.ReadByte() + h.SpinBit = (typeByte & 0x20) == 0x20 h.KeyPhase = (typeByte & 0x04) == 0x04 h.DestinationCID = make([]byte, len(conn.SourceCID)) @@ -187,6 +192,7 @@ func ReadShortHeader(buffer *bytes.Reader, conn *Connection) *ShortHeader { } func NewShortHeader(conn *Connection) *ShortHeader { h := new(ShortHeader) + h.SpinBit = conn.SpinBit h.KeyPhase = conn.KeyPhaseIndex % 2 == 1 h.DestinationCID = conn.DestinationCID h.packetNumber = conn.nextPacketNumber(PNSpaceAppData) @@ -197,3 +203,7 @@ func NewShortHeader(conn *Connection) *ShortHeader { type KeyPhaseBit bool const KeyPhaseZero KeyPhaseBit = false const KeyPhaseOne KeyPhaseBit = true + +type SpinBit bool +const SpinValueZero SpinBit = false +const SpinValueOne SpinBit = true diff --git a/scenarii/scenario.go b/scenarii/scenario.go index e6b09ee..a780c5e 100644 --- a/scenarii/scenario.go +++ b/scenarii/scenario.go @@ -124,5 +124,6 @@ func GetAllScenarii() map[string]Scenario { "http3_encoder_stream": NewHTTP3EncoderStreamScenario(), "http3_uni_streams_limits": NewHTTP3UniStreamsLimitsScenario(), "http3_reserved_frames": NewHTTP3ReservedFramesScenario(), + "spin_bit": NewSpinBitScenario(), } } diff --git a/scenarii/spin_bit.go b/scenarii/spin_bit.go new file mode 100644 index 0000000..55e0f06 --- /dev/null +++ b/scenarii/spin_bit.go @@ -0,0 +1,62 @@ +package scenarii + +import ( + . "github.com/QUIC-Tracker/quic-tracker" +) + +const ( + SB_TLSHandshakeFailed = 1 + SB_DoesNotSpin = 2 +) + +type SpinBitScenario struct { + AbstractScenario +} + +func NewSpinBitScenario() *SpinBitScenario { + return &SpinBitScenario{AbstractScenario{name: "spin_bit", version: 1, ipv6: false}} +} +func (s *SpinBitScenario) Run(conn *Connection, trace *Trace, preferredPath string, debug bool) { + connAgents := s.CompleteHandshake(conn, trace, SB_TLSHandshakeFailed) + if connAgents == nil { + return + } + defer connAgents.CloseConnection(false, 0, "") + + incomingPackets := conn.IncomingPackets.RegisterNewChan(1000) + + conn.SendHTTP09GETRequest(preferredPath, 0) + + var lastServerSpin SpinBit + spins := 0 + +forLoop: + for { + select { + case i := <-incomingPackets: + switch p := i.(type) { + case *ProtectedPacket: + hdr := p.Header().(*ShortHeader) + if hdr.PacketNumber() > conn.LastSpinNumber { + if hdr.SpinBit != lastServerSpin { + lastServerSpin = hdr.SpinBit + spins++ + } + conn.SpinBit = !hdr.SpinBit + conn.LastSpinNumber = hdr.PacketNumber() + } + if conn.Streams.Get(0).ReadClosed && !conn.Streams.Get(4).WriteClosed { + conn.SendHTTP09GETRequest(preferredPath, 4) + } + } + case <-conn.ConnectionClosed: + break forLoop + case <-s.Timeout(): + break forLoop + } + } + + if spins <= 1 { + trace.ErrorCode = SB_DoesNotSpin + } +}