From f99cd6a4b7c8064f8e22742e32e43f2609a27682 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Sat, 20 Jul 2024 00:35:39 -0400 Subject: [PATCH] Perform handshake on first read/write Updates the connection to perform a handshake on first read/write instead of on accept. Closes https://github.com/pion/dtls/issues/279. --- bench_test.go | 2 +- config.go | 21 -- conn.go | 195 +++++++++--------- conn_go_test.go | 45 +++- conn_test.go | 42 ++-- e2e/e2e_test.go | 18 +- examples/dial/cid/main.go | 7 +- examples/dial/psk/main.go | 7 +- examples/dial/selfsign/main.go | 7 +- examples/dial/verify/main.go | 7 +- examples/listen/cid/main.go | 23 +-- examples/listen/psk/main.go | 17 +- examples/listen/selfsign/main.go | 17 +- .../verify-brute-force-protection/main.go | 16 +- examples/listen/verify/main.go | 17 +- handshaker.go | 2 + listener.go | 2 - resume.go | 12 +- 18 files changed, 242 insertions(+), 215 deletions(-) diff --git a/bench_test.go b/bench_test.go index 7b236f6d8..885b311f2 100644 --- a/bench_test.go +++ b/bench_test.go @@ -40,7 +40,7 @@ func TestSimpleReadWrite(t *testing.T) { return } buf := make([]byte, 1024) - if _, sErr = server.Read(buf); sErr != nil { + if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck t.Error(sErr) } gotHello <- struct{}{} diff --git a/config.go b/config.go index 0f6813d24..f97930294 100644 --- a/config.go +++ b/config.go @@ -4,7 +4,6 @@ package dtls import ( - "context" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" @@ -118,15 +117,6 @@ type Config struct { LoggerFactory logging.LoggerFactory - // ConnectContextMaker is a function to make a context used in Dial(), - // Client(), Server(), and Accept(). If nil, the default ConnectContextMaker - // is used. It can be implemented as following. - // - // func ConnectContextMaker() (context.Context, func()) { - // return context.WithTimeout(context.Background(), 30*time.Second) - // } - ConnectContextMaker func() (context.Context, func()) - // MTU is the length at which handshake messages will be fragmented to // fit within the maximum transmission unit (default is 1200 bytes) MTU int @@ -230,17 +220,6 @@ type Config struct { OnConnectionAttempt func(net.Addr) error } -func defaultConnectContextMaker() (context.Context, func()) { - return context.WithTimeout(context.Background(), 30*time.Second) -} - -func (c *Config) connectContextMaker() (context.Context, func()) { - if c.ConnectContextMaker == nil { - return defaultConnectContextMaker() - } - return c.ConnectContextMaker() -} - func (c *Config) includeCertificateSuites() bool { return c.PSK == nil || len(c.Certificates) > 0 || c.GetCertificate != nil || c.GetClientCertificate != nil } diff --git a/conn.go b/conn.go index 0ca5090d2..0d8896550 100644 --- a/conn.go +++ b/conn.go @@ -73,6 +73,7 @@ type Conn struct { paddingLengthGenerator func(uint) uint handshakeCompletedSuccessfully atomic.Value + handshakeMutex sync.Mutex encryptedPackets []addrPkt @@ -94,9 +95,11 @@ type Conn struct { fsm *handshakeFSM replayProtectionWindow uint + + handshakeConfig *handshakeConfig } -func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool) (*Conn, error) { +func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, resumeState *State) (*Conn, error) { if err := validateConfig(config); err != nil { return nil, err } @@ -127,42 +130,6 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien paddingLengthGenerator = func(uint) uint { return 0 } } - c := &Conn{ - rAddr: rAddr, - nextConn: netctx.NewPacketConn(nextConn), - fragmentBuffer: newFragmentBuffer(), - handshakeCache: newHandshakeCache(), - maximumTransmissionUnit: mtu, - paddingLengthGenerator: paddingLengthGenerator, - - decrypted: make(chan interface{}, 1), - log: logger, - - readDeadline: deadline.New(), - writeDeadline: deadline.New(), - - reading: make(chan struct{}, 1), - handshakeRecv: make(chan recvHandshakeState), - closed: closer.NewCloser(), - cancelHandshaker: func() {}, - - replayProtectionWindow: uint(replayProtectionWindow), - - state: State{ - isClient: isClient, - }, - } - - c.setRemoteEpoch(0) - c.setLocalEpoch(0) - return c, nil -} - -func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) { - if conn == nil { - return nil, errNilNextConn - } - cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) if err != nil { return nil, err @@ -190,7 +157,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo curves = defaultCurves } - hsCfg := &handshakeConfig{ + handshakeConfig := &handshakeConfig{ localPSKCallback: config.PSK, localPSKIdentityHint: config.PSKIdentityHint, localCipherSuites: cipherSuites, @@ -209,7 +176,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo customCipherSuites: config.CustomCipherSuites, initialRetransmitInterval: workerInterval, disableRetransmitBackoff: config.DisableRetransmitBackoff, - log: conn.log, + log: logger, initialEpoch: 0, keyLogWriter: config.KeyLogWriter, sessionStore: config.SessionStore, @@ -222,33 +189,97 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo clientHelloMessageHook: config.ClientHelloMessageHook, serverHelloMessageHook: config.ServerHelloMessageHook, certificateRequestMessageHook: config.CertificateRequestMessageHook, + resumeState: resumeState, + } + + c := &Conn{ + rAddr: rAddr, + nextConn: netctx.NewPacketConn(nextConn), + handshakeConfig: handshakeConfig, + fragmentBuffer: newFragmentBuffer(), + handshakeCache: newHandshakeCache(), + maximumTransmissionUnit: mtu, + paddingLengthGenerator: paddingLengthGenerator, + + decrypted: make(chan interface{}, 1), + log: logger, + + readDeadline: deadline.New(), + writeDeadline: deadline.New(), + + reading: make(chan struct{}, 1), + handshakeRecv: make(chan recvHandshakeState), + closed: closer.NewCloser(), + cancelHandshaker: func() {}, + cancelHandshakeReader: func() {}, + + replayProtectionWindow: uint(replayProtectionWindow), + + state: State{ + isClient: isClient, + }, + } + + c.setRemoteEpoch(0) + c.setLocalEpoch(0) + return c, nil +} + +// Handshake runs the client or server DTLS handshake +// protocol if it has not yet been run. +// +// Most uses of this package need not call Handshake explicitly: the +// first [Conn.Read] or [Conn.Write] will call it automatically. +// +// For control over canceling or setting a timeout on a handshake, use +// [Conn.HandshakeContext]. +func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +// HandshakeContext runs the client or server DTLS handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first [Conn.Read] or [Conn.Write] will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if c.isHandshakeCompletedSuccessfully() { + return nil } // rfc5246#section-7.4.3 // In addition, the hash and signature algorithms MUST be compatible // with the key in the server's end-entity certificate. - if !isClient { - cert, err := hsCfg.getCertificate(&ClientHelloInfo{}) + if !c.state.isClient { + cert, err := c.handshakeConfig.getCertificate(&ClientHelloInfo{}) if err != nil && !errors.Is(err, errNoCertificates) { - return nil, err + return err } - hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites) + c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites) } var initialFlight flightVal var initialFSMState handshakeState - if initialState != nil { - if conn.state.isClient { + if c.handshakeConfig.resumeState != nil { + if c.state.isClient { initialFlight = flight5 } else { initialFlight = flight6 } initialFSMState = handshakeFinished - conn.state = *initialState + c.state = *c.handshakeConfig.resumeState } else { - if conn.state.isClient { + if c.state.isClient { initialFlight = flight1 } else { initialFlight = flight0 @@ -256,48 +287,17 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo initialFSMState = handshakePreparing } // Do handshake - if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil { - return nil, err + if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil { + return err } - conn.log.Trace("Handshake Completed") + c.log.Trace("Handshake Completed") - return conn, nil + return nil } // Dial connects to the given network address and establishes a DTLS connection on top. -// Connection handshake will timeout using ConnectContextMaker in the Config. -// If you want to specify the timeout duration, use DialWithContext() instead. func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { - ctx, cancel := config.connectContextMaker() - defer cancel() - - return DialWithContext(ctx, network, rAddr, config) -} - -// Client establishes a DTLS connection over an existing connection. -// Connection handshake will timeout using ConnectContextMaker in the Config. -// If you want to specify the timeout duration, use ClientWithContext() instead. -func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { - ctx, cancel := config.connectContextMaker() - defer cancel() - - return ClientWithContext(ctx, conn, rAddr, config) -} - -// Server listens for incoming DTLS connections. -// Connection handshake will timeout using ConnectContextMaker in the Config. -// If you want to specify the timeout duration, use ServerWithContext() instead. -func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { - ctx, cancel := config.connectContextMaker() - defer cancel() - - return ServerWithContext(ctx, conn, rAddr, config) -} - -// DialWithContext connects to the given network address and establishes a DTLS -// connection on top. -func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { // net.ListenUDP is used rather than net.DialUDP as the latter prevents the // use of net.PacketConn.WriteTo. // https://github.com/golang/go/blob/ce5e37ec21442c6eb13a43e68ca20129102ebac0/src/net/udpsock_posix.go#L115 @@ -306,11 +306,11 @@ func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, co return nil, err } - return ClientWithContext(ctx, pConn, rAddr, config) + return Client(pConn, rAddr, config) } -// ClientWithContext establishes a DTLS connection over an existing connection. -func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { +// Client establishes a DTLS connection over an existing connection. +func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { switch { case config == nil: return nil, errNoConfigProvided @@ -318,16 +318,11 @@ func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, return nil, errPSKAndIdentityMustBeSetForClient } - dconn, err := createConn(conn, rAddr, config, true) - if err != nil { - return nil, err - } - - return handshakeConn(ctx, dconn, config, true, nil) + return createConn(conn, rAddr, config, true, nil) } -// ServerWithContext listens for incoming DTLS connections. -func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { +// Server listens for incoming DTLS connections. +func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { if config == nil { return nil, errNoConfigProvided } @@ -336,17 +331,13 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, return nil, err } } - dconn, err := createConn(conn, rAddr, config, false) - if err != nil { - return nil, err - } - return handshakeConn(ctx, dconn, config, false, nil) + return createConn(conn, rAddr, config, false, nil) } // Read reads data from the connection. func (c *Conn) Read(p []byte) (n int, err error) { - if !c.isHandshakeCompletedSuccessfully() { - return 0, errHandshakeInProgress + if err := c.Handshake(); err != nil { + return 0, err } select { @@ -389,8 +380,8 @@ func (c *Conn) Write(p []byte) (int, error) { default: } - if !c.isHandshakeCompletedSuccessfully() { - return 0, errHandshakeInProgress + if err := c.Handshake(); err != nil { + return 0, err } return len(p), c.writePackets(c.writeDeadline, []*packet{ diff --git a/conn_go_test.go b/conn_go_test.go index d9ca6e187..c79d1b70e 100644 --- a/conn_go_test.go +++ b/conn_go_test.go @@ -52,9 +52,6 @@ func TestContextConfig(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } config := &Config{ - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(context.Background(), 40*time.Millisecond) - }, Certificates: []tls.Certificate{cert}, } @@ -64,9 +61,15 @@ func TestContextConfig(t *testing.T) { }{ "Dial": { f: func() (func() (net.Conn, error), func()) { + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) return func() (net.Conn, error) { - return Dial("udp", addr, config) + conn, err := Dial("udp", addr, config) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) }, func() { + cancel() } }, order: []byte{0, 1, 2}, @@ -75,7 +78,11 @@ func TestContextConfig(t *testing.T) { f: func() (func() (net.Conn, error), func()) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) return func() (net.Conn, error) { - return DialWithContext(ctx, "udp", addr, config) + conn, err := DialWithContext(ctx, "udp", addr, config) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) }, func() { cancel() } @@ -85,10 +92,16 @@ func TestContextConfig(t *testing.T) { "Client": { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) return func() (net.Conn, error) { - return Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + conn, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) }, func() { _ = ca.Close() + cancel() } }, order: []byte{0, 1, 2}, @@ -98,7 +111,11 @@ func TestContextConfig(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) ca, _ := dpipe.Pipe() return func() (net.Conn, error) { - return ClientWithContext(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + conn, err := ClientWithContext(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) }, func() { cancel() _ = ca.Close() @@ -109,10 +126,16 @@ func TestContextConfig(t *testing.T) { "Server": { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) return func() (net.Conn, error) { - return Server(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + conn, err := Server(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) }, func() { _ = ca.Close() + cancel() } }, order: []byte{0, 1, 2}, @@ -122,7 +145,11 @@ func TestContextConfig(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) ca, _ := dpipe.Pipe() return func() (net.Conn, error) { - return ServerWithContext(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + conn, err := ServerWithContext(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) }, func() { cancel() _ = ca.Close() diff --git a/conn_test.go b/conn_test.go index 3d4b3ab47..03364e2af 100644 --- a/conn_test.go +++ b/conn_test.go @@ -295,7 +295,11 @@ func testClient(ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Conf cfg.Certificates = []tls.Certificate{clientCert} } cfg.InsecureSkipVerify = true - return ClientWithContext(ctx, c, rAddr, cfg) + conn, err := ClientWithContext(ctx, c, rAddr, cfg) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) } func testServer(ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Config, generateCertificate bool) (*Conn, error) { @@ -306,7 +310,11 @@ func testServer(ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Conf } cfg.Certificates = []tls.Certificate{serverCert} } - return ServerWithContext(ctx, c, rAddr, cfg) + conn, err := ServerWithContext(ctx, c, rAddr, cfg) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) } func sendClientHello(cookie []byte, ca net.Conn, sequenceNumber uint64, extensions []extension.Extension) error { @@ -1135,17 +1143,18 @@ func TestClientCertificate(t *testing.T) { t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() type result struct { - c *Conn - err error + c *Conn + err, hserr error } c := make(chan result) go func() { client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) - c <- result{client, err} + c <- result{client, err, client.Handshake()} }() server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) + hserr := server.Handshake() res := <-c defer func() { if err == nil { @@ -1157,7 +1166,7 @@ func TestClientCertificate(t *testing.T) { }() if tt.wantErr { - if err != nil { + if err != nil || hserr != nil { // Error expected, test succeeded return } @@ -1556,23 +1565,24 @@ func TestServerCertificate(t *testing.T) { ca, cb := dpipe.Pipe() type result struct { - c *Conn - err error + c *Conn + err, hserr error } srvCh := make(chan result) go func() { s, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) - srvCh <- result{s, err} + srvCh <- result{s, err, s.Handshake()} }() cli, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) + hserr := cli.Handshake() if err == nil { _ = cli.Close() } - if !tt.wantErr && err != nil { - t.Errorf("Client failed(%v)", err) + if !tt.wantErr && (err != nil || hserr != nil) { + t.Errorf("Client failed(%v, %v)", err, hserr) } - if tt.wantErr && err == nil { + if tt.wantErr && err == nil && hserr == nil { t.Fatal("Error expected") } @@ -3237,7 +3247,7 @@ func TestSkipHelloVerify(t *testing.T) { return } buf := make([]byte, 1024) - if _, sErr = server.Read(buf); sErr != nil { + if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck t.Error(sErr) } gotHello <- struct{}{} @@ -3306,7 +3316,7 @@ func TestApplicationDataQueueLimited(t *testing.T) { cfg := &Config{} cfg.Certificates = []tls.Certificate{serverCert} - dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false) + dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false, nil) if err != nil { t.Error(err) return @@ -3322,7 +3332,7 @@ func TestApplicationDataQueueLimited(t *testing.T) { time.Sleep(1 * time.Second) } }() - if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil { + if err := dconn.HandshakeContext(ctx); err == nil { t.Error("expected handshake to fail") } close(done) @@ -3402,7 +3412,7 @@ func TestHelloRandom(t *testing.T) { return } buf := make([]byte, 1024) - if _, sErr = server.Read(buf); sErr != nil { + if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck t.Error(sErr) } gotHello <- struct{}{} diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index ec1253ec8..0d8ba35bc 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -219,8 +219,7 @@ func clientPion(c *comm) { c.clientMutex.Lock() defer c.clientMutex.Unlock() - var err error - c.clientConn, err = dtls.DialWithContext(c.ctx, "udp", + conn, err := dtls.Dial("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort}, c.clientConfig, ) @@ -229,6 +228,13 @@ func clientPion(c *comm) { return } + if err := conn.HandshakeContext(c.ctx); err != nil { + c.errChan <- err + return + } + + c.clientConn = conn + simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount) c.clientDone <- nil close(c.clientDone) @@ -254,6 +260,14 @@ func serverPion(c *comm) { return } + dtlsConn, ok := c.serverConn.(*dtls.Conn) + if ok { + if err := dtlsConn.HandshakeContext(c.ctx); err != nil { + c.errChan <- err + return + } + } + simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount) c.serverDone <- nil close(c.serverDone) diff --git a/examples/dial/cid/main.go b/examples/dial/cid/main.go index 10e547706..4859e72b1 100644 --- a/examples/dial/cid/main.go +++ b/examples/dial/cid/main.go @@ -37,12 +37,17 @@ func main() { // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) + dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() + if err := dtlsConn.HandshakeContext(ctx); err != nil { + fmt.Printf("Failed to handshake with server: %v\n", err) + return + } + fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session diff --git a/examples/dial/psk/main.go b/examples/dial/psk/main.go index b70efdcc3..94731e93a 100644 --- a/examples/dial/psk/main.go +++ b/examples/dial/psk/main.go @@ -36,12 +36,17 @@ func main() { // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) + dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() + if err := dtlsConn.HandshakeContext(ctx); err != nil { + fmt.Printf("Failed to handshake with server: %v\n", err) + return + } + fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session diff --git a/examples/dial/selfsign/main.go b/examples/dial/selfsign/main.go index 5fa25a923..b3be5a8f8 100644 --- a/examples/dial/selfsign/main.go +++ b/examples/dial/selfsign/main.go @@ -38,12 +38,17 @@ func main() { // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) + dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() + if err := dtlsConn.HandshakeContext(ctx); err != nil { + fmt.Printf("Failed to handshake with server: %v\n", err) + return + } + fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session diff --git a/examples/dial/verify/main.go b/examples/dial/verify/main.go index 07501954d..ed5352dd5 100644 --- a/examples/dial/verify/main.go +++ b/examples/dial/verify/main.go @@ -45,12 +45,17 @@ func main() { // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) + dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() + if err := dtlsConn.HandshakeContext(ctx); err != nil { + fmt.Printf("Failed to handshake with server: %v\n", err) + return + } + fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session diff --git a/examples/listen/cid/main.go b/examples/listen/cid/main.go index 770bbcfa4..11be4bc21 100644 --- a/examples/listen/cid/main.go +++ b/examples/listen/cid/main.go @@ -8,7 +8,6 @@ import ( "context" "fmt" "net" - "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" @@ -18,10 +17,6 @@ func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -32,13 +27,9 @@ func main() { fmt.Printf("Client's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, - PSKIdentityHint: []byte("Pion DTLS Server"), - CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, - ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, + PSKIdentityHint: []byte("Pion DTLS Server"), + CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, + ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ConnectionIDGenerator: dtls.RandomCIDGenerator(8), } @@ -65,6 +56,14 @@ func main() { // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. + // Perform the handshake with a 30-second timeout + ctx, cancel := context.WithTimeout(context.Background(), 30) + dtlsConn, ok := conn.(*dtls.Conn) + if ok { + util.Check(dtlsConn.HandshakeContext(ctx)) + } + cancel() + // Register the connection with the chat hub if err == nil { hub.Register(conn) diff --git a/examples/listen/psk/main.go b/examples/listen/psk/main.go index 66f099693..e6c16514c 100644 --- a/examples/listen/psk/main.go +++ b/examples/listen/psk/main.go @@ -8,7 +8,6 @@ import ( "context" "fmt" "net" - "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" @@ -18,10 +17,6 @@ func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -35,10 +30,6 @@ func main() { PSKIdentityHint: []byte("Pion DTLS Server"), CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, } // Connect to a DTLS server @@ -64,6 +55,14 @@ func main() { // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. + // Perform the handshake with a 30-second timeout + ctx, cancel := context.WithTimeout(context.Background(), 30) + dtlsConn, ok := conn.(*dtls.Conn) + if ok { + util.Check(dtlsConn.HandshakeContext(ctx)) + } + cancel() + // Register the connection with the chat hub if err == nil { hub.Register(conn) diff --git a/examples/listen/selfsign/main.go b/examples/listen/selfsign/main.go index 025b667e4..af1010102 100644 --- a/examples/listen/selfsign/main.go +++ b/examples/listen/selfsign/main.go @@ -9,7 +9,6 @@ import ( "crypto/tls" "fmt" "net" - "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" @@ -24,10 +23,6 @@ func main() { certificate, genErr := selfsign.GenerateSelfSigned() util.Check(genErr) - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -36,10 +31,6 @@ func main() { config := &dtls.Config{ Certificates: []tls.Certificate{certificate}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, } // Connect to a DTLS server @@ -65,6 +56,14 @@ func main() { // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. + // Perform the handshake with a 30-second timeout + ctx, cancel := context.WithTimeout(context.Background(), 30) + dtlsConn, ok := conn.(*dtls.Conn) + if ok { + util.Check(dtlsConn.HandshakeContext(ctx)) + } + cancel() + // Register the connection with the chat hub if err == nil { hub.Register(conn) diff --git a/examples/listen/verify-brute-force-protection/main.go b/examples/listen/verify-brute-force-protection/main.go index b5fb82c42..bcb518164 100644 --- a/examples/listen/verify-brute-force-protection/main.go +++ b/examples/listen/verify-brute-force-protection/main.go @@ -22,10 +22,6 @@ func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -52,10 +48,6 @@ func main() { ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ClientAuth: dtls.RequireAndVerifyClientCert, ClientCAs: certPool, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, // This function will be called on each connection attempt. OnConnectionAttempt: func(addr net.Addr) error { // *************** Brute Force Attack protection *************** @@ -122,6 +114,14 @@ func main() { attemptsMutex.Unlock() // *************** END Brute Force Attack protection END *************** + // Perform the handshake with a 30-second timeout + ctx, cancel := context.WithTimeout(context.Background(), 30) + dtlsConn, ok := conn.(*dtls.Conn) + if ok { + util.Check(dtlsConn.HandshakeContext(ctx)) + } + cancel() + // Register the connection with the chat hub hub.Register(conn) } diff --git a/examples/listen/verify/main.go b/examples/listen/verify/main.go index a02211e15..5ef6e038f 100644 --- a/examples/listen/verify/main.go +++ b/examples/listen/verify/main.go @@ -10,7 +10,6 @@ import ( "crypto/x509" "fmt" "net" - "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" @@ -20,10 +19,6 @@ func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -45,10 +40,6 @@ func main() { ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ClientAuth: dtls.RequireAndVerifyClientCert, ClientCAs: certPool, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, } // Connect to a DTLS server @@ -74,6 +65,14 @@ func main() { // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. + // Perform the handshake with a 30-second timeout + ctx, cancel := context.WithTimeout(context.Background(), 30) + dtlsConn, ok := conn.(*dtls.Conn) + if ok { + util.Check(dtlsConn.HandshakeContext(ctx)) + } + cancel() + // Register the connection with the chat hub hub.Register(conn) } diff --git a/handshaker.go b/handshaker.go index a585c3db8..62a4bf6e8 100644 --- a/handshaker.go +++ b/handshaker.go @@ -132,6 +132,8 @@ type handshakeConfig struct { clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message + + resumeState *State } type flightConn interface { diff --git a/listener.go b/listener.go index 90dbbb427..cb75d4143 100644 --- a/listener.go +++ b/listener.go @@ -67,8 +67,6 @@ type listener struct { // Accept waits for and returns the next connection to the listener. // You have to either close or read on all connection that are created. -// Connection handshake will timeout using ConnectContextMaker in the Config. -// If you want to specify the timeout duration, set ConnectContextMaker. func (l *listener) Accept() (net.Conn, error) { c, raddr, err := l.parent.Accept() if err != nil { diff --git a/resume.go b/resume.go index 6cd1c5a69..0b76314a5 100644 --- a/resume.go +++ b/resume.go @@ -4,7 +4,6 @@ package dtls import ( - "context" "net" ) @@ -13,14 +12,5 @@ func Resume(state *State, conn net.PacketConn, rAddr net.Addr, config *Config) ( if err := state.initCipherSuite(); err != nil { return nil, err } - dconn, err := createConn(conn, rAddr, config, state.isClient) - if err != nil { - return nil, err - } - c, err := handshakeConn(context.Background(), dconn, config, state.isClient, state) - if err != nil { - return nil, err - } - - return c, nil + return createConn(conn, rAddr, config, state.isClient, state) }