Skip to content

Commit

Permalink
0-RTT packets can be coalesced with Initials
Browse files Browse the repository at this point in the history
  • Loading branch information
mpiraux committed Nov 7, 2019
1 parent 99f685f commit 21c3a86
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 37 deletions.
4 changes: 2 additions & 2 deletions agents/handshake_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (a *HandshakeAgent) Run(conn *Connection) {
select {
case <-a.sendInitial:
a.Logger.Println("Sending first Initial packet")
conn.SendPacket(conn.GetInitialPacket(), EncryptionLevelInitial)
conn.SendPacket.Submit(PacketToSend{Packet: conn.GetInitialPacket(), EncryptionLevel: EncryptionLevelInitial})
case p := <-incPackets:
switch p := p.(type) {
case *VersionNegotiationPacket:
Expand Down Expand Up @@ -128,7 +128,7 @@ func (a *HandshakeAgent) Run(conn *Connection) {
tlsStatus = a.TLSAgent.TLSStatus.RegisterNewChan(10)
socketStatus = a.SocketAgent.SocketStatus.RegisterNewChan(10)
conn.ConnectionRestarted = make(chan bool, 1)
conn.SendPacket(conn.GetInitialPacket(), EncryptionLevelInitial)
conn.SendPacket.Submit(PacketToSend{Packet: conn.GetInitialPacket(), EncryptionLevel: EncryptionLevelInitial})
case <-a.close:
return
}
Expand Down
2 changes: 1 addition & 1 deletion agents/recovery_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func (a *RecoveryAgent) RetransmitBatch(batch RetransmitBatch) {
if b.Level == EncryptionLevelInitial && (len(b.Frames) > 200 || b.Frames[0].FrameType() == StreamType) { // Simple heuristic to detect first Initial packet
packet := NewInitialPacket(a.conn)
packet.Frames = b.Frames
a.conn.SendPacket(packet, EncryptionLevelInitial)
a.conn.SendPacket.Submit(PacketToSend{Packet: packet, EncryptionLevel: EncryptionLevelInitial})
return
}
for _, f := range b.Frames {
Expand Down
66 changes: 57 additions & 9 deletions agents/send_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ import (
// It also merge the ACK frames inside a given packet before sending.
type SendingAgent struct {
BaseAgent
MTU uint16
FrameProducer []FrameProducer
MTU uint16
FrameProducer []FrameProducer
DontCoalesceZeroRTT bool
}

func (a *SendingAgent) Run(conn *Connection) {
a.Init("SendingAgent", conn.OriginalDestinationCID)

preparePacket := conn.PreparePacket.RegisterNewChan(100)
sendPacket := conn.SendPacket.RegisterNewChan(100)
newEncryptionLevelAvailable := conn.EncryptionLevelsAvailable.RegisterNewChan(10)

encryptionLevels := []EncryptionLevel{EncryptionLevelInitial, EncryptionLevel0RTT, EncryptionLevelHandshake, EncryptionLevel1RTT}
Expand All @@ -39,7 +41,9 @@ func (a *SendingAgent) Run(conn *Connection) {
}
}

fillAndSendPacket := func(packet Framer, level EncryptionLevel) {
initialSent := false

fillPacket := func(packet Framer, level EncryptionLevel) Framer {
spaceLeft := int(a.MTU) - packet.Header().HeaderLength() - conn.CryptoStates[level].Write.Overhead()

addFrame:
Expand Down Expand Up @@ -70,9 +74,9 @@ func (a *SendingAgent) Run(conn *Connection) {
if len(packet.GetFrames()) == 0 {
a.Logger.Printf("Preparing a packet for encryption level %s resulted in an empty packet, discarding\n", level.String())
conn.PacketNumber[packet.PNSpace()]-- // Avoids PN skipping
} else {
conn.SendPacket(packet, level)
return nil
}
return packet
}

go func() {
Expand All @@ -94,16 +98,31 @@ func (a *SendingAgent) Run(conn *Connection) {
timersArmed[eL] = true
}
case <-timers[EncryptionLevelInitial].C:
fillAndSendPacket(NewInitialPacket(conn), EncryptionLevelInitial)
p := fillPacket(NewInitialPacket(conn), EncryptionLevelInitial)
if p != nil {
initialSent = true
conn.DoSendPacket(p, EncryptionLevelInitial)
}
timersArmed[EncryptionLevelInitial] = false
case <-timers[EncryptionLevel0RTT].C:
fillAndSendPacket(NewZeroRTTProtectedPacket(conn), EncryptionLevel0RTT)
if initialSent {
p := fillPacket(NewZeroRTTProtectedPacket(conn), EncryptionLevel0RTT)
if p != nil {
conn.DoSendPacket(p, EncryptionLevel0RTT)
}
}
timersArmed[EncryptionLevel0RTT] = false
case <-timers[EncryptionLevelHandshake].C:
fillAndSendPacket(NewHandshakePacket(conn), EncryptionLevelHandshake)
p := fillPacket(NewHandshakePacket(conn), EncryptionLevelHandshake)
if p != nil {
conn.DoSendPacket(p, EncryptionLevelHandshake)
}
timersArmed[EncryptionLevelHandshake] = false
case <-timers[EncryptionLevel1RTT].C:
fillAndSendPacket(NewProtectedPacket(conn), EncryptionLevel1RTT)
p := fillPacket(NewProtectedPacket(conn), EncryptionLevel1RTT)
if p != nil {
conn.DoSendPacket(p, EncryptionLevel1RTT)
}
timersArmed[EncryptionLevel1RTT] = false
case i := <-newEncryptionLevelAvailable:
dEL := i.(DirectionalEncryptionLevel)
Expand All @@ -115,6 +134,35 @@ func (a *SendingAgent) Run(conn *Connection) {
bestEncryptionLevels[EncryptionLevelBest] = chooseBestEncryptionLevel(encryptionLevelsAvailable, false)
bestEncryptionLevels[EncryptionLevelBestAppData] = chooseBestEncryptionLevel(encryptionLevelsAvailable, true)
timers[eL].Reset(2 * time.Millisecond)
case i := <-sendPacket:
p := i.(PacketToSend)
if p.EncryptionLevel == EncryptionLevelInitial && p.Packet.Header().PacketType() == Initial {
initial := p.Packet.(*InitialPacket)
if !a.DontCoalesceZeroRTT && bestEncryptionLevels[EncryptionLevelBestAppData] == EncryptionLevel0RTT {
// Try to prepare a 0-RTT packet and squeeze it after the Initial
zp := NewZeroRTTProtectedPacket(conn)
fillPacket(zp, EncryptionLevel0RTT)
if len(zp.GetFrames()) > 0 {
zpBytes := conn.EncodeAndEncrypt(zp, EncryptionLevel0RTT)
initialFrames := initial.GetFrames()
initialLength := len(conn.EncodeAndEncrypt(initial, EncryptionLevelInitial))
initial.Frames = nil
for _, f := range initialFrames {
if f.FrameType() != PaddingFrameType {
initial.Frames = append(initial.Frames, f)
}
}
initial.PadTo(initialLength - len(zpBytes))
coalescedPackets := append(conn.EncodeAndEncrypt(initial, EncryptionLevelInitial), zpBytes...)
conn.UdpConnection.Write(coalescedPackets)
conn.PacketWasSent(initial)
conn.PacketWasSent(zp)
continue
}
}
initialSent = true
}
conn.DoSendPacket(p.Packet, p.EncryptionLevel)
case <-a.close:
return
}
Expand Down
5 changes: 5 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,4 +249,9 @@ type UnprocessedPayload struct {
type QueuedFrame struct {
Frame
EncryptionLevel
}

type PacketToSend struct {
Packet
EncryptionLevel
}
37 changes: 24 additions & 13 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/hex"
"errors"
"fmt"
. "github.com/QUIC-Tracker/quic-tracker/lib"
"github.com/mpiraux/pigotls"
"log"
"net"
Expand Down Expand Up @@ -48,6 +47,7 @@ type Connection struct {
TransportParameters Broadcaster //type: QuicTransportParameters

PreparePacket Broadcaster //type: EncryptionLevel
SendPacket Broadcaster //type: PacketToSend
StreamInput Broadcaster //type: StreamInput

ConnectionClosed chan bool
Expand Down Expand Up @@ -82,10 +82,9 @@ func (c *Connection) nextPacketNumber(space PNSpace) PacketNumber { // TODO: Th
c.PacketNumber[space]++
return pn
}
func (c *Connection) SendPacket(packet Packet, level EncryptionLevel) {
func (c *Connection) EncodeAndEncrypt(packet Packet, level EncryptionLevel) []byte {
switch packet.PNSpace() {
case PNSpaceInitial, PNSpaceHandshake, PNSpaceAppData:
c.Logger.Printf("Sending packet {type=%s, number=%d}\n", packet.Header().PacketType().String(), packet.Header().PacketNumber())
cryptoState := c.CryptoStates[level]

payload := packet.EncodePayload()
Expand All @@ -109,12 +108,27 @@ func (c *Connection) SendPacket(packet Packet, level EncryptionLevel) {
packetBytes[pnOffset+i] ^= mask[1+i]
}

c.UdpConnection.Write(packetBytes)
return packetBytes
default:
// Clients do not send cleartext packets
}
return nil
}

if c.SentPacketHandler != nil {
c.SentPacketHandler(packet.Encode(packet.EncodePayload()), packet.Pointer())
}
c.OutgoingPackets.Submit(packet)
func (c *Connection) PacketWasSent(packet Packet) {
if c.SentPacketHandler != nil {
c.SentPacketHandler(packet.Encode(packet.EncodePayload()), packet.Pointer())
}
c.OutgoingPackets.Submit(packet)
}
func (c *Connection) DoSendPacket(packet Packet, level EncryptionLevel) {
switch packet.PNSpace() {
case PNSpaceInitial, PNSpaceHandshake, PNSpaceAppData:
c.Logger.Printf("Sending packet {type=%s, number=%d}\n", packet.Header().PacketType().String(), packet.Header().PacketNumber())

c.UdpConnection.Write(c.EncodeAndEncrypt(packet, level))

c.PacketWasSent(packet)
default:
// Clients do not send cleartext packets
}
Expand Down Expand Up @@ -150,11 +164,7 @@ func (c *Connection) GetInitialPacket() *InitialPacket {

initialPacket := NewInitialPacket(c)
initialPacket.Frames = append(initialPacket.Frames, cryptoFrame)
payloadLen := len(initialPacket.EncodePayload())
paddingLength := initialLength - (len(initialPacket.header.Encode()) + int(VarIntLen(uint64(payloadLen))) + payloadLen + c.CryptoStates[EncryptionLevelInitial].Write.Overhead())
for i := 0; i < paddingLength; i++ {
initialPacket.Frames = append(initialPacket.Frames, new(PaddingFrame))
}
initialPacket.PadTo(initialLength - c.CryptoStates[EncryptionLevelInitial].Write.Overhead())

return initialPacket
}
Expand Down Expand Up @@ -311,6 +321,7 @@ func NewConnection(serverName string, version uint32, ALPN string, SCID []byte,
c.ConnectionRestart = make(chan bool, 1)
c.ConnectionRestarted = make(chan bool, 1)
c.PreparePacket = NewBroadcaster(1000)
c.SendPacket = NewBroadcaster(1000)
c.StreamInput = NewBroadcaster(1000)

c.Logger = log.New(os.Stderr, fmt.Sprintf("[CID %s] ", hex.EncodeToString(c.OriginalDestinationCID)), log.Lshortfile)
Expand Down
16 changes: 16 additions & 0 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ type Framer interface {
OnlyContains(frameType FrameType) bool
GetFirst(frameType FrameType) Frame
GetAll(frameType FrameType) []Frame
PadTo(length int)
}
type FramePacket struct {
abstractPacket
Expand Down Expand Up @@ -181,6 +182,21 @@ func (p *FramePacket) GetAll(frameType FrameType) []Frame {
}
return frames
}
func (p *FramePacket) PadTo(length int) {
switch h := p.Header().(type) {
case *LongHeader:
h.Length = NewVarInt(uint64(len(p.EncodePayload())))
}
currentLen := len(p.Encode(p.EncodePayload()))
for currentLen < length {
p.AddFrame(new(PaddingFrame))
switch h := p.Header().(type) {
case *LongHeader:
h.Length = NewVarInt(h.Length.Value + 1)
}
currentLen = len(p.Encode(p.EncodePayload()))
}
}
func (p *FramePacket) ShouldBeAcknowledged() bool {
for _, frame := range p.Frames {
switch frame.FrameType() {
Expand Down
2 changes: 1 addition & 1 deletion scenarii/ack_only.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (s *AckOnlyScenario) Run(conn *qt.Connection, trace *qt.Trace, preferredPat
}

packet.AddFrame(ackFrame)
conn.SendPacket(packet, packet.EncryptionLevel())
conn.DoSendPacket(packet, packet.EncryptionLevel())
if p.PNSpace() == qt.PNSpaceAppData {
ackOnlyPackets = append(ackOnlyPackets, packet.Header().PacketNumber())
}
Expand Down
2 changes: 1 addition & 1 deletion scenarii/padding.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (s *PaddingScenario) Run(conn *qt.Connection, trace *qt.Trace, preferredPat
initialPacket.Frames = append(initialPacket.Frames, new(qt.PaddingFrame))
}

conn.SendPacket(initialPacket, qt.EncryptionLevelInitial)
conn.DoSendPacket(initialPacket, qt.EncryptionLevelInitial)
}

incPackets := conn.IncomingPackets.RegisterNewChan(1000)
Expand Down
4 changes: 2 additions & 2 deletions scenarii/stream_opening_reordering.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ func (s *StreamOpeningReorderingScenario) Run(conn *qt.Connection, trace *qt.Tra
pp2 := qt.NewProtectedPacket(conn)
pp2.Frames = append(pp2.Frames, qt.NewStreamFrame(0, uint64(len(payload)), []byte{}, true))

conn.SendPacket(pp2, qt.EncryptionLevel1RTT)
conn.SendPacket(pp1, qt.EncryptionLevel1RTT)
conn.DoSendPacket(pp2, qt.EncryptionLevel1RTT)
conn.DoSendPacket(pp1, qt.EncryptionLevel1RTT)

forLoop:
for {
Expand Down
2 changes: 1 addition & 1 deletion scenarii/unsupported_tls_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,5 @@ func sendUnsupportedInitial(conn *qt.Connection) {
frame.CryptoData = bytes.Replace(frame.CryptoData, []byte{0x0, 0x2b, 0x0, 0x03, 0x2, 0x03, 0x04}, []byte{0x0, 0x2b, 0x0, 0x03, 0x2, 0x7f, 0x00}, 1)
}
}
conn.SendPacket(initialPacket, qt.EncryptionLevelInitial)
conn.DoSendPacket(initialPacket, qt.EncryptionLevelInitial)
}
4 changes: 2 additions & 2 deletions scenarii/version_negotiation.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (s *VersionNegotiationScenario) Run(conn *qt.Connection, trace *qt.Trace, p
conn.Version = ForceVersionNegotiation
trace.ErrorCode = VN_Timeout
initial := conn.GetInitialPacket()
conn.SendPacket(initial, qt.EncryptionLevelInitial)
conn.DoSendPacket(initial, qt.EncryptionLevelInitial)

threshold := 3
vnCount := 0
Expand All @@ -54,7 +54,7 @@ func (s *VersionNegotiationScenario) Run(conn *qt.Connection, trace *qt.Trace, p
trace.Results["supported_versions"] = p.SupportedVersions // TODO: Compare versions announced ?
newInitial := qt.NewInitialPacket(conn)
newInitial.Frames = initial.Frames
conn.SendPacket(newInitial, qt.EncryptionLevelInitial)
conn.DoSendPacket(newInitial, qt.EncryptionLevelInitial)
}
case qt.Packet:
trace.MarkError(VN_NotAnsweringToVN, "", p)
Expand Down
7 changes: 2 additions & 5 deletions scenarii/zero_rtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,13 @@ func (s *ZeroRTTScenario) Run(conn *qt.Connection, trace *qt.Trace, preferredPat
incPackets = conn.IncomingPackets.RegisterNewChan(1000)
encryptionLevelsAvailable := conn.EncryptionLevelsAvailable.RegisterNewChan(10)

handshakeAgent.InitiateHandshake()
responseChan := connAgents.AddHTTPAgent().SendRequest(preferredPath, "GET", trace.Host, nil) // TODO: Verify that this get effectively sent in a 0-RTT packet
handshakeAgent.InitiateHandshake() // TODO: Handle stateless connection

if !s.waitFor0RTT(conn, trace, encryptionLevelsAvailable) {
return
}

// TODO: Handle stateless connection

responseChan := connAgents.AddHTTPAgent().SendRequest(preferredPath, "GET", trace.Host, nil) // TODO: Verify that this get sent in a 0-RTT packet

trace.ErrorCode = ZR_DidntReceiveTheRequestedData
for {
select {
Expand Down

0 comments on commit 21c3a86

Please sign in to comment.