diff --git a/conn_test.go b/conn_test.go index 57fd5cb30..c3d6549b8 100644 --- a/conn_test.go +++ b/conn_test.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/hex" + "encoding/pem" "errors" "fmt" "io" @@ -2486,6 +2487,269 @@ func TestSessionResume(t *testing.T) { } _ = res.c.Close() }) + + t.Run("resumed client cert", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + type result struct { + c *Conn + err error + } + clientRes := make(chan result, 1) + + commonCert, _ := selfsign.GenerateSelfSignedWithDNS("example.com") + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: commonCert.Certificate[0]})) + + ss := &memSessStore{} + + id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306") + secret, _ := hex.DecodeString("2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7") + + s := Session{ID: id, Secret: secret} + + ca, cb := dpipe.Pipe() + + _ = ss.Set(id, s) + _ = ss.Set([]byte(ca.RemoteAddr().String()+"_"+commonCert.Leaf.Subject.CommonName), s) + + go func() { + config := &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + ServerName: commonCert.Leaf.Subject.CommonName, + SessionStore: ss, + RootCAs: certPool, + Certificates: nil, // Client shouldn't need to send a cert to resume a session + MTU: 100, + } + c, err := ClientWithContext(ctx, ca, config) + clientRes <- result{c, err} + }() + + config := &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SessionStore: ss, + ClientCAs: certPool, + Certificates: []tls.Certificate{commonCert}, + MTU: 100, + ClientAuth: RequireAndVerifyClientCert, + } + server, err := testServer(ctx, cb, config, true) + if err != nil { + t.Fatalf("TestSessionResume: Server failed(%v)", err) + } + + actualSessionID := server.ConnectionState().SessionID + actualMasterSecret := server.ConnectionState().masterSecret + if !bytes.Equal(actualSessionID, id) { + t.Errorf("TestSessionResumetion: SessionID Mismatch: expected(%v) actual(%v)", id, actualSessionID) + } + if !bytes.Equal(actualMasterSecret, secret) { + t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", secret, actualMasterSecret) + } + + defer func() { + _ = server.Close() + }() + + res := <-clientRes + if res.err != nil { + t.Fatal(res.err) + } + _ = res.c.Close() + }) + + t.Run("new session client cert", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + type result struct { + c *Conn + err error + } + clientRes := make(chan result, 1) + + commonCert, _ := selfsign.GenerateSelfSignedWithDNS("example.com") + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: commonCert.Certificate[0]})) + + s1 := &memSessStore{} + s2 := &memSessStore{} + + ca, cb := dpipe.Pipe() + go func() { + config := &Config{ + ServerName: commonCert.Leaf.Subject.CommonName, + SessionStore: s1, + RootCAs: certPool, + Certificates: []tls.Certificate{commonCert}, + } + c, err := ClientWithContext(ctx, ca, config) + clientRes <- result{c, err} + }() + + config := &Config{ + SessionStore: s2, + ClientAuth: RequireAndVerifyClientCert, + ClientCAs: certPool, + Certificates: []tls.Certificate{commonCert}, + } + server, err := testServer(ctx, cb, config, false) + if err != nil { + t.Fatalf("TestSessionResumetion: Server failed(%v)", err) + } + + actualSessionID := server.ConnectionState().SessionID + actualMasterSecret := server.ConnectionState().masterSecret + ss, _ := s2.Get(actualSessionID) + if !bytes.Equal(actualMasterSecret, ss.Secret) { + t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected/actual:\n(%v)\n(%v)", ss.Secret, actualMasterSecret) + } + + if ss.Expiry.Unix() != commonCert.Leaf.NotAfter.Unix() { + t.Errorf("TestSessionResumption: expected server session store to contain certificate expiry") + } + + defer func() { + _ = server.Close() + }() + + res := <-clientRes + if res.err != nil { + t.Fatal(res.err) + } + cs, _ := s1.Get([]byte(ca.RemoteAddr().String() + "_" + commonCert.Leaf.Subject.CommonName)) + if !bytes.Equal(actualMasterSecret, cs.Secret) { + t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected/actual\n(%v)\n(%v)", cs.Secret, actualMasterSecret) + } + + if cs.Expiry.Unix() != commonCert.Leaf.NotAfter.Unix() { + t.Errorf("TestSessionResumption: expected client session store to contain certificate expiry") + } + + _ = res.c.Close() + }) + + t.Run("expire client cert session", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + type result struct { + c *Conn + err error + } + clientRes := make(chan result, 1) + + commonCert, _ := selfsign.GenerateSelfSignedWithDNS("example.com") + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: commonCert.Certificate[0]})) + + ss := &memSessStore{} + + id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306") + secret, _ := hex.DecodeString("2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7") + + oldClientSessionTime := time.Now().Add(time.Hour) + clientSession := Session{ + ID: id, + Secret: secret, + Expiry: oldClientSessionTime, + } + + expiredServerSession := Session{ + ID: id, + Secret: secret, + Expiry: time.Now().Add(-time.Hour), // server should treat this as expired session and force a new cert verification + } + + ca, cb := dpipe.Pipe() + + _ = ss.Set(id, expiredServerSession) + _ = ss.Set([]byte(ca.RemoteAddr().String()+"_"+commonCert.Leaf.Subject.CommonName), clientSession) + + go func() { + config := &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + ServerName: commonCert.Leaf.Subject.CommonName, + SessionStore: ss, + RootCAs: certPool, + Certificates: []tls.Certificate{commonCert}, + MTU: 1200, // MTU must be able to fit cert chain in one packet + } + c, err := ClientWithContext(ctx, ca, config) + clientRes <- result{c, err} + }() + + config := &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SessionStore: ss, + ClientCAs: certPool, + Certificates: []tls.Certificate{commonCert}, + MTU: 1200, + ClientAuth: RequireAndVerifyClientCert, + } + server, err := testServer(ctx, cb, config, false) + if err != nil { + t.Fatalf("TestSessionResume: Server failed(%v)", err) + } + + actualSessionID := server.ConnectionState().SessionID + actualMasterSecret := server.ConnectionState().masterSecret + if bytes.Equal(actualSessionID, id) { + t.Errorf("TestSessionResumption: SessionID Mismatch: expected new session ID(%v) actual(%v)", id, actualSessionID) + } + + if bytes.Equal(actualMasterSecret, secret) { + t.Errorf("TestSessionResumption: masterSecret Mismatch: expected new master secret (%v) actual(%v)", secret, actualMasterSecret) + } + + defer func() { + _ = server.Close() + }() + + res := <-clientRes + if res.err != nil { + t.Fatal(res.err) + } + + _, ok := ss.Map.Load(hex.EncodeToString(expiredServerSession.ID)) + if ok { + t.Errorf("expected server to have deleted session") + } + + cSess, ok := ss.Map.Load(hex.EncodeToString([]byte(ca.RemoteAddr().String() + "_" + commonCert.Leaf.Subject.CommonName))) + if !ok { + t.Errorf("expected client store to have cached new session ID") + } + + newClientSession := cSess.(Session) + if bytes.Equal(secret, newClientSession.Secret) { + t.Errorf("expected : expected client session store to contain new master secret (%v) actual(%v)", secret, newClientSession.Secret) + } + + if newClientSession.Expiry.Unix() == oldClientSessionTime.Unix() { + t.Errorf("expected new client session to have updated") + } + + if newClientSession.Expiry.Unix() != commonCert.Leaf.NotAfter.Unix() { + t.Errorf("expected new client session to expire with client cert") + } + + sSess, ok := ss.Map.Load(hex.EncodeToString(newClientSession.ID)) + if !ok { + t.Errorf("expected server store to have cached new client session ID") + } + newServerSession := sSess.(Session) + + if !bytes.Equal(newServerSession.Secret, newClientSession.Secret) { + t.Errorf("expected : expected session store to contain new shared secret (%v) actual(%v)", newServerSession.Secret, newClientSession.Secret) + } + _ = res.c.Close() + }) } type memSessStore struct { @@ -2507,7 +2771,16 @@ func (ms *memSessStore) Get(key []byte) (Session, error) { return Session{}, nil } - return v.(Session), nil + session := v.(Session) + if session.Expiry.IsZero() { + return session, nil + } + + if time.Now().After(session.Expiry) { + _ = ms.Del(key) + return Session{}, nil + } + return session, nil } func (ms *memSessStore) Del(key []byte) error {