From 21c3a864b6e2a2c0f0f6a0921c774119a0ea88a8 Mon Sep 17 00:00:00 2001 From: Maxime Piraux Date: Thu, 7 Nov 2019 11:25:36 +0100 Subject: [PATCH] 0-RTT packets can be coalesced with Initials --- agents/handshake_agent.go | 4 +- agents/recovery_agent.go | 2 +- agents/send_agent.go | 66 +++++++++++++++++++++++---- common.go | 5 ++ connection.go | 37 +++++++++------ packets.go | 16 +++++++ scenarii/ack_only.go | 2 +- scenarii/padding.go | 2 +- scenarii/stream_opening_reordering.go | 4 +- scenarii/unsupported_tls_version.go | 2 +- scenarii/version_negotiation.go | 4 +- scenarii/zero_rtt.go | 7 +-- 12 files changed, 114 insertions(+), 37 deletions(-) diff --git a/agents/handshake_agent.go b/agents/handshake_agent.go index 25fbd7d..a9b263b 100644 --- a/agents/handshake_agent.go +++ b/agents/handshake_agent.go @@ -55,7 +55,7 @@ func (a *HandshakeAgent) Run(conn *Connection) { select { case <-a.sendInitial: a.Logger.Println("Sending first Initial packet") - conn.SendPacket(conn.GetInitialPacket(), EncryptionLevelInitial) + conn.SendPacket.Submit(PacketToSend{Packet: conn.GetInitialPacket(), EncryptionLevel: EncryptionLevelInitial}) case p := <-incPackets: switch p := p.(type) { case *VersionNegotiationPacket: @@ -128,7 +128,7 @@ func (a *HandshakeAgent) Run(conn *Connection) { tlsStatus = a.TLSAgent.TLSStatus.RegisterNewChan(10) socketStatus = a.SocketAgent.SocketStatus.RegisterNewChan(10) conn.ConnectionRestarted = make(chan bool, 1) - conn.SendPacket(conn.GetInitialPacket(), EncryptionLevelInitial) + conn.SendPacket.Submit(PacketToSend{Packet: conn.GetInitialPacket(), EncryptionLevel: EncryptionLevelInitial}) case <-a.close: return } diff --git a/agents/recovery_agent.go b/agents/recovery_agent.go index 99ec669..b39453b 100644 --- a/agents/recovery_agent.go +++ b/agents/recovery_agent.go @@ -151,7 +151,7 @@ func (a *RecoveryAgent) RetransmitBatch(batch RetransmitBatch) { if b.Level == EncryptionLevelInitial && (len(b.Frames) > 200 || b.Frames[0].FrameType() == StreamType) { // Simple heuristic to detect first Initial packet packet := NewInitialPacket(a.conn) packet.Frames = b.Frames - a.conn.SendPacket(packet, EncryptionLevelInitial) + a.conn.SendPacket.Submit(PacketToSend{Packet: packet, EncryptionLevel: EncryptionLevelInitial}) return } for _, f := range b.Frames { diff --git a/agents/send_agent.go b/agents/send_agent.go index ccfd32b..fc9c84d 100644 --- a/agents/send_agent.go +++ b/agents/send_agent.go @@ -11,14 +11,16 @@ import ( // It also merge the ACK frames inside a given packet before sending. type SendingAgent struct { BaseAgent - MTU uint16 - FrameProducer []FrameProducer + MTU uint16 + FrameProducer []FrameProducer + DontCoalesceZeroRTT bool } func (a *SendingAgent) Run(conn *Connection) { a.Init("SendingAgent", conn.OriginalDestinationCID) preparePacket := conn.PreparePacket.RegisterNewChan(100) + sendPacket := conn.SendPacket.RegisterNewChan(100) newEncryptionLevelAvailable := conn.EncryptionLevelsAvailable.RegisterNewChan(10) encryptionLevels := []EncryptionLevel{EncryptionLevelInitial, EncryptionLevel0RTT, EncryptionLevelHandshake, EncryptionLevel1RTT} @@ -39,7 +41,9 @@ func (a *SendingAgent) Run(conn *Connection) { } } - fillAndSendPacket := func(packet Framer, level EncryptionLevel) { + initialSent := false + + fillPacket := func(packet Framer, level EncryptionLevel) Framer { spaceLeft := int(a.MTU) - packet.Header().HeaderLength() - conn.CryptoStates[level].Write.Overhead() addFrame: @@ -70,9 +74,9 @@ func (a *SendingAgent) Run(conn *Connection) { if len(packet.GetFrames()) == 0 { a.Logger.Printf("Preparing a packet for encryption level %s resulted in an empty packet, discarding\n", level.String()) conn.PacketNumber[packet.PNSpace()]-- // Avoids PN skipping - } else { - conn.SendPacket(packet, level) + return nil } + return packet } go func() { @@ -94,16 +98,31 @@ func (a *SendingAgent) Run(conn *Connection) { timersArmed[eL] = true } case <-timers[EncryptionLevelInitial].C: - fillAndSendPacket(NewInitialPacket(conn), EncryptionLevelInitial) + p := fillPacket(NewInitialPacket(conn), EncryptionLevelInitial) + if p != nil { + initialSent = true + conn.DoSendPacket(p, EncryptionLevelInitial) + } timersArmed[EncryptionLevelInitial] = false case <-timers[EncryptionLevel0RTT].C: - fillAndSendPacket(NewZeroRTTProtectedPacket(conn), EncryptionLevel0RTT) + if initialSent { + p := fillPacket(NewZeroRTTProtectedPacket(conn), EncryptionLevel0RTT) + if p != nil { + conn.DoSendPacket(p, EncryptionLevel0RTT) + } + } timersArmed[EncryptionLevel0RTT] = false case <-timers[EncryptionLevelHandshake].C: - fillAndSendPacket(NewHandshakePacket(conn), EncryptionLevelHandshake) + p := fillPacket(NewHandshakePacket(conn), EncryptionLevelHandshake) + if p != nil { + conn.DoSendPacket(p, EncryptionLevelHandshake) + } timersArmed[EncryptionLevelHandshake] = false case <-timers[EncryptionLevel1RTT].C: - fillAndSendPacket(NewProtectedPacket(conn), EncryptionLevel1RTT) + p := fillPacket(NewProtectedPacket(conn), EncryptionLevel1RTT) + if p != nil { + conn.DoSendPacket(p, EncryptionLevel1RTT) + } timersArmed[EncryptionLevel1RTT] = false case i := <-newEncryptionLevelAvailable: dEL := i.(DirectionalEncryptionLevel) @@ -115,6 +134,35 @@ func (a *SendingAgent) Run(conn *Connection) { bestEncryptionLevels[EncryptionLevelBest] = chooseBestEncryptionLevel(encryptionLevelsAvailable, false) bestEncryptionLevels[EncryptionLevelBestAppData] = chooseBestEncryptionLevel(encryptionLevelsAvailable, true) timers[eL].Reset(2 * time.Millisecond) + case i := <-sendPacket: + p := i.(PacketToSend) + if p.EncryptionLevel == EncryptionLevelInitial && p.Packet.Header().PacketType() == Initial { + initial := p.Packet.(*InitialPacket) + if !a.DontCoalesceZeroRTT && bestEncryptionLevels[EncryptionLevelBestAppData] == EncryptionLevel0RTT { + // Try to prepare a 0-RTT packet and squeeze it after the Initial + zp := NewZeroRTTProtectedPacket(conn) + fillPacket(zp, EncryptionLevel0RTT) + if len(zp.GetFrames()) > 0 { + zpBytes := conn.EncodeAndEncrypt(zp, EncryptionLevel0RTT) + initialFrames := initial.GetFrames() + initialLength := len(conn.EncodeAndEncrypt(initial, EncryptionLevelInitial)) + initial.Frames = nil + for _, f := range initialFrames { + if f.FrameType() != PaddingFrameType { + initial.Frames = append(initial.Frames, f) + } + } + initial.PadTo(initialLength - len(zpBytes)) + coalescedPackets := append(conn.EncodeAndEncrypt(initial, EncryptionLevelInitial), zpBytes...) + conn.UdpConnection.Write(coalescedPackets) + conn.PacketWasSent(initial) + conn.PacketWasSent(zp) + continue + } + } + initialSent = true + } + conn.DoSendPacket(p.Packet, p.EncryptionLevel) case <-a.close: return } diff --git a/common.go b/common.go index f5a85e0..c7a496d 100644 --- a/common.go +++ b/common.go @@ -249,4 +249,9 @@ type UnprocessedPayload struct { type QueuedFrame struct { Frame EncryptionLevel +} + +type PacketToSend struct { + Packet + EncryptionLevel } \ No newline at end of file diff --git a/connection.go b/connection.go index 7209ce8..0a0d738 100644 --- a/connection.go +++ b/connection.go @@ -5,7 +5,6 @@ import ( "encoding/hex" "errors" "fmt" - . "github.com/QUIC-Tracker/quic-tracker/lib" "github.com/mpiraux/pigotls" "log" "net" @@ -48,6 +47,7 @@ type Connection struct { TransportParameters Broadcaster //type: QuicTransportParameters PreparePacket Broadcaster //type: EncryptionLevel + SendPacket Broadcaster //type: PacketToSend StreamInput Broadcaster //type: StreamInput ConnectionClosed chan bool @@ -82,10 +82,9 @@ func (c *Connection) nextPacketNumber(space PNSpace) PacketNumber { // TODO: Th c.PacketNumber[space]++ return pn } -func (c *Connection) SendPacket(packet Packet, level EncryptionLevel) { +func (c *Connection) EncodeAndEncrypt(packet Packet, level EncryptionLevel) []byte { switch packet.PNSpace() { case PNSpaceInitial, PNSpaceHandshake, PNSpaceAppData: - c.Logger.Printf("Sending packet {type=%s, number=%d}\n", packet.Header().PacketType().String(), packet.Header().PacketNumber()) cryptoState := c.CryptoStates[level] payload := packet.EncodePayload() @@ -109,12 +108,27 @@ func (c *Connection) SendPacket(packet Packet, level EncryptionLevel) { packetBytes[pnOffset+i] ^= mask[1+i] } - c.UdpConnection.Write(packetBytes) + return packetBytes + default: + // Clients do not send cleartext packets + } + return nil +} - if c.SentPacketHandler != nil { - c.SentPacketHandler(packet.Encode(packet.EncodePayload()), packet.Pointer()) - } - c.OutgoingPackets.Submit(packet) +func (c *Connection) PacketWasSent(packet Packet) { + if c.SentPacketHandler != nil { + c.SentPacketHandler(packet.Encode(packet.EncodePayload()), packet.Pointer()) + } + c.OutgoingPackets.Submit(packet) +} +func (c *Connection) DoSendPacket(packet Packet, level EncryptionLevel) { + switch packet.PNSpace() { + case PNSpaceInitial, PNSpaceHandshake, PNSpaceAppData: + c.Logger.Printf("Sending packet {type=%s, number=%d}\n", packet.Header().PacketType().String(), packet.Header().PacketNumber()) + + c.UdpConnection.Write(c.EncodeAndEncrypt(packet, level)) + + c.PacketWasSent(packet) default: // Clients do not send cleartext packets } @@ -150,11 +164,7 @@ func (c *Connection) GetInitialPacket() *InitialPacket { initialPacket := NewInitialPacket(c) initialPacket.Frames = append(initialPacket.Frames, cryptoFrame) - payloadLen := len(initialPacket.EncodePayload()) - paddingLength := initialLength - (len(initialPacket.header.Encode()) + int(VarIntLen(uint64(payloadLen))) + payloadLen + c.CryptoStates[EncryptionLevelInitial].Write.Overhead()) - for i := 0; i < paddingLength; i++ { - initialPacket.Frames = append(initialPacket.Frames, new(PaddingFrame)) - } + initialPacket.PadTo(initialLength - c.CryptoStates[EncryptionLevelInitial].Write.Overhead()) return initialPacket } @@ -311,6 +321,7 @@ func NewConnection(serverName string, version uint32, ALPN string, SCID []byte, c.ConnectionRestart = make(chan bool, 1) c.ConnectionRestarted = make(chan bool, 1) c.PreparePacket = NewBroadcaster(1000) + c.SendPacket = NewBroadcaster(1000) c.StreamInput = NewBroadcaster(1000) c.Logger = log.New(os.Stderr, fmt.Sprintf("[CID %s] ", hex.EncodeToString(c.OriginalDestinationCID)), log.Lshortfile) diff --git a/packets.go b/packets.go index b515f95..078e46d 100644 --- a/packets.go +++ b/packets.go @@ -125,6 +125,7 @@ type Framer interface { OnlyContains(frameType FrameType) bool GetFirst(frameType FrameType) Frame GetAll(frameType FrameType) []Frame + PadTo(length int) } type FramePacket struct { abstractPacket @@ -181,6 +182,21 @@ func (p *FramePacket) GetAll(frameType FrameType) []Frame { } return frames } +func (p *FramePacket) PadTo(length int) { + switch h := p.Header().(type) { + case *LongHeader: + h.Length = NewVarInt(uint64(len(p.EncodePayload()))) + } + currentLen := len(p.Encode(p.EncodePayload())) + for currentLen < length { + p.AddFrame(new(PaddingFrame)) + switch h := p.Header().(type) { + case *LongHeader: + h.Length = NewVarInt(h.Length.Value + 1) + } + currentLen = len(p.Encode(p.EncodePayload())) + } +} func (p *FramePacket) ShouldBeAcknowledged() bool { for _, frame := range p.Frames { switch frame.FrameType() { diff --git a/scenarii/ack_only.go b/scenarii/ack_only.go index 7f1b639..c732fa6 100644 --- a/scenarii/ack_only.go +++ b/scenarii/ack_only.go @@ -53,7 +53,7 @@ func (s *AckOnlyScenario) Run(conn *qt.Connection, trace *qt.Trace, preferredPat } packet.AddFrame(ackFrame) - conn.SendPacket(packet, packet.EncryptionLevel()) + conn.DoSendPacket(packet, packet.EncryptionLevel()) if p.PNSpace() == qt.PNSpaceAppData { ackOnlyPackets = append(ackOnlyPackets, packet.Header().PacketNumber()) } diff --git a/scenarii/padding.go b/scenarii/padding.go index 8ad23df..ff701e3 100644 --- a/scenarii/padding.go +++ b/scenarii/padding.go @@ -37,7 +37,7 @@ func (s *PaddingScenario) Run(conn *qt.Connection, trace *qt.Trace, preferredPat initialPacket.Frames = append(initialPacket.Frames, new(qt.PaddingFrame)) } - conn.SendPacket(initialPacket, qt.EncryptionLevelInitial) + conn.DoSendPacket(initialPacket, qt.EncryptionLevelInitial) } incPackets := conn.IncomingPackets.RegisterNewChan(1000) diff --git a/scenarii/stream_opening_reordering.go b/scenarii/stream_opening_reordering.go index 083f81d..4ca77ed 100644 --- a/scenarii/stream_opening_reordering.go +++ b/scenarii/stream_opening_reordering.go @@ -45,8 +45,8 @@ func (s *StreamOpeningReorderingScenario) Run(conn *qt.Connection, trace *qt.Tra pp2 := qt.NewProtectedPacket(conn) pp2.Frames = append(pp2.Frames, qt.NewStreamFrame(0, uint64(len(payload)), []byte{}, true)) - conn.SendPacket(pp2, qt.EncryptionLevel1RTT) - conn.SendPacket(pp1, qt.EncryptionLevel1RTT) + conn.DoSendPacket(pp2, qt.EncryptionLevel1RTT) + conn.DoSendPacket(pp1, qt.EncryptionLevel1RTT) forLoop: for { diff --git a/scenarii/unsupported_tls_version.go b/scenarii/unsupported_tls_version.go index 4df3dfc..c786221 100644 --- a/scenarii/unsupported_tls_version.go +++ b/scenarii/unsupported_tls_version.go @@ -78,5 +78,5 @@ func sendUnsupportedInitial(conn *qt.Connection) { frame.CryptoData = bytes.Replace(frame.CryptoData, []byte{0x0, 0x2b, 0x0, 0x03, 0x2, 0x03, 0x04}, []byte{0x0, 0x2b, 0x0, 0x03, 0x2, 0x7f, 0x00}, 1) } } - conn.SendPacket(initialPacket, qt.EncryptionLevelInitial) + conn.DoSendPacket(initialPacket, qt.EncryptionLevelInitial) } diff --git a/scenarii/version_negotiation.go b/scenarii/version_negotiation.go index bcacf1f..f621d7d 100644 --- a/scenarii/version_negotiation.go +++ b/scenarii/version_negotiation.go @@ -32,7 +32,7 @@ func (s *VersionNegotiationScenario) Run(conn *qt.Connection, trace *qt.Trace, p conn.Version = ForceVersionNegotiation trace.ErrorCode = VN_Timeout initial := conn.GetInitialPacket() - conn.SendPacket(initial, qt.EncryptionLevelInitial) + conn.DoSendPacket(initial, qt.EncryptionLevelInitial) threshold := 3 vnCount := 0 @@ -54,7 +54,7 @@ func (s *VersionNegotiationScenario) Run(conn *qt.Connection, trace *qt.Trace, p trace.Results["supported_versions"] = p.SupportedVersions // TODO: Compare versions announced ? newInitial := qt.NewInitialPacket(conn) newInitial.Frames = initial.Frames - conn.SendPacket(newInitial, qt.EncryptionLevelInitial) + conn.DoSendPacket(newInitial, qt.EncryptionLevelInitial) } case qt.Packet: trace.MarkError(VN_NotAnsweringToVN, "", p) diff --git a/scenarii/zero_rtt.go b/scenarii/zero_rtt.go index 59f9833..aca9f2a 100644 --- a/scenarii/zero_rtt.go +++ b/scenarii/zero_rtt.go @@ -79,16 +79,13 @@ func (s *ZeroRTTScenario) Run(conn *qt.Connection, trace *qt.Trace, preferredPat incPackets = conn.IncomingPackets.RegisterNewChan(1000) encryptionLevelsAvailable := conn.EncryptionLevelsAvailable.RegisterNewChan(10) - handshakeAgent.InitiateHandshake() + responseChan := connAgents.AddHTTPAgent().SendRequest(preferredPath, "GET", trace.Host, nil) // TODO: Verify that this get effectively sent in a 0-RTT packet + handshakeAgent.InitiateHandshake() // TODO: Handle stateless connection if !s.waitFor0RTT(conn, trace, encryptionLevelsAvailable) { return } - // TODO: Handle stateless connection - - responseChan := connAgents.AddHTTPAgent().SendRequest(preferredPath, "GET", trace.Host, nil) // TODO: Verify that this get sent in a 0-RTT packet - trace.ErrorCode = ZR_DidntReceiveTheRequestedData for { select {