Skip to content

Commit

Permalink
Adds a lock to CryptoStates, fixes #23
Browse files Browse the repository at this point in the history
  • Loading branch information
mpiraux committed Apr 28, 2020
1 parent 2cbac1f commit 1da4265
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 7 deletions.
2 changes: 1 addition & 1 deletion agents/parse_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (a *ParsingAgent) Run(conn *Connection) {
}

header := ReadHeader(bytes.NewReader(ciphertext), a.conn)
cryptoState := a.conn.CryptoStates[header.EncryptionLevel()]
cryptoState := a.conn.CryptoState(header.EncryptionLevel())

switch header.PacketType() {
case Initial, Handshake, ZeroRTTProtected, ShortHeaderPacket: // Decrypt PN
Expand Down
6 changes: 3 additions & 3 deletions agents/send_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (a *SendingAgent) Run(conn *Connection) {
initialSent := false

fillPacket := func(packet Framer, level EncryptionLevel) Framer {
spaceLeft := int(a.MTU) - packet.Header().HeaderLength() - conn.CryptoStates[level].Write.Overhead()
spaceLeft := int(a.MTU) - packet.Header().HeaderLength() - conn.CryptoState(level).Write.Overhead()

addFrame:
for i, fp := range a.FrameProducer {
Expand Down Expand Up @@ -114,7 +114,7 @@ func (a *SendingAgent) Run(conn *Connection) {
} else {
initialLength = MinimumInitialLength
}
initialLength -= conn.CryptoStates[EncryptionLevelInitial].Write.Overhead()
initialLength -= conn.CryptoState(EncryptionLevelInitial).Write.Overhead()
p.PadTo(initialLength)
initialSent = true
conn.DoSendPacket(p, EncryptionLevelInitial)
Expand Down Expand Up @@ -191,7 +191,7 @@ func (a *SendingAgent) Run(conn *Connection) {
} else {
initialLength = MinimumInitialLength
}
initialLength -= conn.CryptoStates[EncryptionLevelInitial].Write.Overhead()
initialLength -= conn.CryptoState(EncryptionLevelInitial).Write.Overhead()
initial.PadTo(initialLength)
initialSent = true
}
Expand Down
2 changes: 2 additions & 0 deletions agents/tls_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func (a *TLSAgent) Run(conn *Connection) {
a.TLSStatus.Submit(TLSStatus{false, packet, err})
}

conn.CryptoStateLock.Lock()
if conn.CryptoStates[EncryptionLevelHandshake] == nil {
conn.CryptoStates[EncryptionLevelHandshake] = new(CryptoState)
}
Expand Down Expand Up @@ -126,6 +127,7 @@ func (a *TLSAgent) Run(conn *Connection) {
conn.EncryptionLevels.Submit(*e)
}
}
conn.CryptoStateLock.Unlock()

if !resumptionTicketSent && len(conn.Tls.ResumptionTicket()) > 0 {
a.ResumptionTicket.Submit(conn.Tls.ResumptionTicket())
Expand Down
19 changes: 17 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type Connection struct {
SpinBit SpinBit
LastSpinNumber PacketNumber

CryptoStateLock sync.Locker
CryptoStates map[EncryptionLevel]*CryptoState

ReceivedPacketHandler func([]byte, unsafe.Pointer)
Expand Down Expand Up @@ -90,10 +91,19 @@ func (c *Connection) nextPacketNumber(space PNSpace) PacketNumber { // TODO: Th
c.PacketNumberLock.Unlock()
return pn
}
func (c *Connection) CryptoState(level EncryptionLevel) *CryptoState {
c.CryptoStateLock.Lock()
cs, ok := c.CryptoStates[level]
c.CryptoStateLock.Unlock()
if ok {
return cs
}
return nil
}
func (c *Connection) EncodeAndEncrypt(packet Packet, level EncryptionLevel) []byte {
switch packet.PNSpace() {
case PNSpaceInitial, PNSpaceHandshake, PNSpaceAppData:
cryptoState := c.CryptoStates[level]
cryptoState := c.CryptoState(level)

payload := packet.EncodePayload()
if h, ok := packet.Header().(*LongHeader); ok {
Expand Down Expand Up @@ -161,7 +171,9 @@ func (c *Connection) GetInitialPacket() *InitialPacket {

if len(c.Tls.ZeroRTTSecret()) > 0 {
c.Logger.Printf("0-RTT secret is available, installing crypto state")
c.CryptoStateLock.Lock()
c.CryptoStates[EncryptionLevel0RTT] = NewProtectedCryptoState(c.Tls, nil, c.Tls.ZeroRTTSecret())
c.CryptoStateLock.Unlock()
c.EncryptionLevels.Submit(DirectionalEncryptionLevel{EncryptionLevel: EncryptionLevel0RTT, Read: false, Available: true})
}

Expand All @@ -174,7 +186,7 @@ func (c *Connection) GetInitialPacket() *InitialPacket {

initialPacket := NewInitialPacket(c)
initialPacket.Frames = append(initialPacket.Frames, cryptoFrame)
initialPacket.PadTo(initialLength - c.CryptoStates[EncryptionLevelInitial].Write.Overhead())
initialPacket.PadTo(initialLength - c.CryptoState(EncryptionLevelInitial).Write.Overhead())

return initialPacket
}
Expand Down Expand Up @@ -250,9 +262,12 @@ func (c *Connection) TransitionTo(version uint32, ALPN string) {
c.AckQueue[space] = nil
}

c.CryptoStateLock = &sync.Mutex{}
c.CryptoStateLock.Lock()
c.CryptoStates = make(map[EncryptionLevel]*CryptoState)
c.CryptoStreams = make(map[PNSpace]*Stream)
c.CryptoStates[EncryptionLevelInitial] = NewInitialPacketProtection(c)
c.CryptoStateLock.Unlock()
c.Streams = Streams{streams: make(map[uint64]*Stream), lock: &sync.Mutex{}, input: &c.StreamInput}
}
func (c *Connection) CloseConnection(quicLayer bool, errCode uint64, reasonPhrase string) {
Expand Down
2 changes: 2 additions & 0 deletions scenarii/key_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ forLoop1:
readSecret := conn.Tls.HkdfExpandLabel(conn.Tls.ProtectedReadSecret(), "ku", nil, conn.Tls.HashDigestSize(), pigotls.QuicBaseLabel)
writeSecret := conn.Tls.HkdfExpandLabel(conn.Tls.ProtectedWriteSecret(), "ku", nil, conn.Tls.HashDigestSize(), pigotls.QuicBaseLabel)

conn.CryptoStateLock.Lock()
oldState := conn.CryptoStates[qt.EncryptionLevel1RTT]

conn.CryptoStates[qt.EncryptionLevel1RTT] = qt.NewProtectedCryptoState(conn.Tls, readSecret, writeSecret)
conn.CryptoStates[qt.EncryptionLevel1RTT].HeaderRead = oldState.HeaderRead
conn.CryptoStates[qt.EncryptionLevel1RTT].HeaderWrite = oldState.HeaderWrite
conn.KeyPhaseIndex++
conn.CryptoStateLock.Unlock()

responseChan := connAgents.AddHTTPAgent().SendRequest(preferredPath, "GET", trace.Host, nil)

Expand Down
2 changes: 1 addition & 1 deletion scenarii/padding.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (s *PaddingScenario) Run(conn *qt.Connection, trace *qt.Trace, preferredPat

initialPacket := qt.NewInitialPacket(conn)
payloadLen := len(initialPacket.EncodePayload())
paddingLength := initialLength - (len(initialPacket.Header().Encode()) + int(VarIntLen(uint64(payloadLen))) + payloadLen + conn.CryptoStates[qt.EncryptionLevelInitial].Write.Overhead())
paddingLength := initialLength - (len(initialPacket.Header().Encode()) + int(VarIntLen(uint64(payloadLen))) + payloadLen + conn.CryptoState(qt.EncryptionLevelInitial).Write.Overhead())
for i := 0; i < paddingLength; i++ {
initialPacket.Frames = append(initialPacket.Frames, new(qt.PaddingFrame))
}
Expand Down

0 comments on commit 1da4265

Please sign in to comment.