diff --git a/agents/parse_agent.go b/agents/parse_agent.go index 289ce06..0a808a8 100644 --- a/agents/parse_agent.go +++ b/agents/parse_agent.go @@ -2,7 +2,6 @@ package agents import ( . "github.com/QUIC-Tracker/quic-tracker" - . "github.com/QUIC-Tracker/quic-tracker/lib" "unsafe" "bytes" ) @@ -78,9 +77,9 @@ func (a *ParsingAgent) Run(conn *Connection) { break packetSelect } - payload, err := cryptoState.Read.Open(nil, EncodeArgs(header.PacketNumber()), ciphertext[hLen:hLen+pLen], ciphertext[:hLen]) - if err != nil { - a.Logger.Printf("Could not decrypt packet {type=%s, number=%d}: %s\n", header.PacketType().String(), header.PacketNumber(), err.Error()) + payload := cryptoState.Read.Decrypt(ciphertext[hLen:hLen+pLen], uint64(header.PacketNumber()), ciphertext[:hLen]) + if payload == nil { + a.Logger.Printf("Could not decrypt packet {type=%s, number=%d}\n", header.PacketType().String(), header.PacketNumber()) break packetSelect } @@ -94,9 +93,9 @@ func (a *ParsingAgent) Run(conn *Connection) { off += hLen + pLen case ShortHeaderPacket: // Packets with a short header always include a 1-RTT protected payload. - payload, err := cryptoState.Read.Open(nil, EncodeArgs(header.PacketNumber()), ciphertext[hLen:], ciphertext[:hLen]) - if err != nil { - a.Logger.Printf("Could not decrypt packet {type=%s, number=%d}: %s\n", header.PacketType().String(), header.PacketNumber(), err.Error()) + payload := cryptoState.Read.Decrypt(ciphertext[hLen:], uint64(header.PacketNumber()), ciphertext[:hLen]) + if payload == nil { + a.Logger.Printf("Could not decrypt packet {type=%s, number=%d}\n", header.PacketType().String(), header.PacketNumber()) break packetSelect } cleartext = append(append(cleartext, udpPayload[off:off+hLen]...), payload...) diff --git a/connection.go b/connection.go index 18d9874..255c857 100644 --- a/connection.go +++ b/connection.go @@ -77,7 +77,7 @@ func (c *Connection) SendPacket(packet Packet, level EncryptionLevel) { } header := packet.EncodeHeader() - protectedPayload := cryptoState.Write.Seal(nil, EncodeArgs(packet.Header().PacketNumber()), payload, header) + protectedPayload := cryptoState.Write.Encrypt(payload, uint64(packet.Header().PacketNumber()), header) packetBytes := append(header, protectedPayload...) sample, pnOffset := GetPacketSample(packet.Header(), packetBytes) diff --git a/crypto.go b/crypto.go index 74e0a90..bcbe0a2 100644 --- a/crypto.go +++ b/crypto.go @@ -1,9 +1,7 @@ package quictracker import ( - "crypto/cipher" "github.com/mpiraux/pigotls" - . "github.com/QUIC-Tracker/quic-tracker/lib" ) var quicVersionSalt = []byte{ // See https://tools.ietf.org/html/draft-ietf-quic-tls-10#section-5.2.2 @@ -62,19 +60,19 @@ type DirectionalEncryptionLevel struct { } type CryptoState struct { - Read cipher.AEAD - Write cipher.AEAD + Read *pigotls.AEAD + Write *pigotls.AEAD PacketRead *pigotls.Cipher PacketWrite *pigotls.Cipher } func (s *CryptoState) InitRead(tls *pigotls.Connection, readSecret []byte) { - s.Read = newProtectedAead(tls, readSecret) + s.Read = tls.NewAEAD(readSecret, false) s.PacketRead = tls.NewCipher(tls.HkdfExpandLabel(readSecret, "pn", nil, tls.AEADKeySize())) } func (s *CryptoState) InitWrite(tls *pigotls.Connection, writeSecret []byte) { - s.Write = newProtectedAead(tls, writeSecret) + s.Write = tls.NewAEAD(writeSecret, true) s.PacketWrite = tls.NewCipher(tls.HkdfExpandLabel(writeSecret, "pn", nil, tls.AEADKeySize())) } @@ -96,17 +94,6 @@ func NewProtectedCryptoState(tls *pigotls.Connection, readSecret []byte, writeSe return s } -func newProtectedAead(tls *pigotls.Connection, secret []byte) cipher.AEAD { - k := tls.HkdfExpandLabel(secret, "key", nil, tls.AEADKeySize()) - iv := tls.HkdfExpandLabel(secret, "iv", nil, tls.AEADIvSize()) - - aead, err := NewWrappedAESGCM(k, iv) - if err != nil { - panic(err) - } - return aead -} - func GetPacketSample(header Header, packetBytes []byte) ([]byte, int) { var pnOffset int sampleLength := 16