Skip to content

Commit

Permalink
remove context constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
kevmo314 committed Jul 20, 2024
1 parent dd8d722 commit 4073924
Show file tree
Hide file tree
Showing 14 changed files with 79 additions and 122 deletions.
21 changes: 0 additions & 21 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package dtls

import (
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
Expand Down Expand Up @@ -118,15 +117,6 @@ type Config struct {

LoggerFactory logging.LoggerFactory

// ConnectContextMaker is a function to make a context used in Dial(),
// Client(), Server(), and Accept(). If nil, the default ConnectContextMaker
// is used. It can be implemented as following.
//
// func ConnectContextMaker() (context.Context, func()) {
// return context.WithTimeout(context.Background(), 30*time.Second)
// }
ConnectContextMaker func() (context.Context, func())

// MTU is the length at which handshake messages will be fragmented to
// fit within the maximum transmission unit (default is 1200 bytes)
MTU int
Expand Down Expand Up @@ -230,17 +220,6 @@ type Config struct {
OnConnectionAttempt func(net.Addr) error
}

func defaultConnectContextMaker() (context.Context, func()) {
return context.WithTimeout(context.Background(), 30*time.Second)
}

func (c *Config) connectContextMaker() (context.Context, func()) {
if c.ConnectContextMaker == nil {
return defaultConnectContextMaker()
}
return c.ConnectContextMaker()
}

func (c *Config) includeCertificateSuites() bool {
return c.PSK == nil || len(c.Certificates) > 0 || c.GetCertificate != nil || c.GetClientCertificate != nil
}
Expand Down
81 changes: 45 additions & 36 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,28 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien
return c, nil
}

// Handshake runs the client or server DTLS handshake
// protocol if it has not yet been run.
//
// Most uses of this package need not call Handshake explicitly: the
// first [Conn.Read] or [Conn.Write] will call it automatically.
//
// For control over canceling or setting a timeout on a handshake, use
// [Conn.HandshakeContext].
func (c *Conn) Handshake() error {
return c.HandshakeContext(context.Background())
}

// HandshakeContext runs the client or server DTLS handshake
// protocol if it has not yet been run.
//
// The provided Context must be non-nil. If the context is canceled before
// the handshake is complete, the handshake is interrupted and an error is returned.
// Once the handshake has completed, cancellation of the context will not affect the
// connection.
//
// Most uses of this package need not call HandshakeContext explicitly: the
// first [Conn.Read] or [Conn.Write] will call it automatically.
func (c *Conn) HandshakeContext(ctx context.Context) error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
Expand Down Expand Up @@ -279,38 +297,7 @@ func (c *Conn) HandshakeContext(ctx context.Context) error {
}

// Dial connects to the given network address and establishes a DTLS connection on top.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use DialWithContext() instead.
func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()

return DialWithContext(ctx, network, rAddr, config)
}

// Client establishes a DTLS connection over an existing connection.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ClientWithContext() instead.
func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()

return ClientWithContext(ctx, conn, rAddr, config)
}

// Server listens for incoming DTLS connections.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ServerWithContext() instead.
func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()

return ServerWithContext(ctx, conn, rAddr, config)
}

// DialWithContext connects to the given network address and establishes a DTLS
// connection on top.
func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) {
// net.ListenUDP is used rather than net.DialUDP as the latter prevents the
// use of net.PacketConn.WriteTo.
// https://github.com/golang/go/blob/ce5e37ec21442c6eb13a43e68ca20129102ebac0/src/net/udpsock_posix.go#L115
Expand All @@ -319,11 +306,11 @@ func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, co
return nil, err
}

return ClientWithContext(ctx, pConn, rAddr, config)
return Client(pConn, rAddr, config)
}

// ClientWithContext establishes a DTLS connection over an existing connection.
func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
// Client establishes a DTLS connection over an existing connection.
func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
switch {
case config == nil:
return nil, errNoConfigProvided
Expand All @@ -334,8 +321,8 @@ func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr,
return createConn(conn, rAddr, config, true, nil)
}

// ServerWithContext listens for incoming DTLS connections.
func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
// Server listens for incoming DTLS connections.
func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
if config == nil {
return nil, errNoConfigProvided
}
Expand All @@ -347,6 +334,28 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr,
return createConn(conn, rAddr, config, false, nil)
}

// DialWithContext connects to the given network address and establishes a DTLS
// connection on top.
//
// Deprecated: Use Dial instead, the context parameter is no longer used.
func DialWithContext(_ context.Context, network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) {
return Dial(network, rAddr, config)
}

// ClientWithContext establishes a DTLS connection over an existing connection.
//
// Deprecated: Use Client instead, the context parameter is no longer used.
func ClientWithContext(_ context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
return Client(conn, rAddr, config)
}

// ServerWithContext listens for incoming DTLS connections.
//
// Deprecated: Use Server instead, the context parameter is no longer used.
func ServerWithContext(_ context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
return Server(conn, rAddr, config)
}

// Read reads data from the connection.
func (c *Conn) Read(p []byte) (n int, err error) {
if err := c.Handshake(); err != nil {
Expand Down
3 changes: 0 additions & 3 deletions conn_go_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ func TestContextConfig(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}
config := &Config{
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(context.Background(), 40*time.Millisecond)
},
Certificates: []tls.Certificate{cert},
}

Expand Down
11 changes: 7 additions & 4 deletions e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func clientPion(c *comm) {
c.clientMutex.Lock()
defer c.clientMutex.Unlock()

conn, err := dtls.DialWithContext(c.ctx, "udp",
conn, err := dtls.Dial("udp",
&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
c.clientConfig,
)
Expand Down Expand Up @@ -260,9 +260,12 @@ func serverPion(c *comm) {
return
}

if err := (c.serverConn.(*dtls.Conn)).HandshakeContext(c.ctx); err != nil {
c.errChan <- err
return
dtlsConn, ok := c.serverConn.(*dtls.Conn)
if ok {
if err := dtlsConn.HandshakeContext(c.ctx); err != nil {
c.errChan <- err
return
}
}

simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount)
Expand Down
7 changes: 6 additions & 1 deletion examples/dial/cid/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,17 @@ func main() {
// Connect to a DTLS server
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config)
dtlsConn, err := dtls.Dial("udp", addr, config)
util.Check(err)
defer func() {
util.Check(dtlsConn.Close())
}()

if err := dtlsConn.HandshakeContext(ctx); err != nil {
fmt.Printf("Failed to handshake with server: %v\n", err)
return
}

fmt.Println("Connected; type 'exit' to shutdown gracefully")

// Simulate a chat session
Expand Down
7 changes: 6 additions & 1 deletion examples/dial/psk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,17 @@ func main() {
// Connect to a DTLS server
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config)
dtlsConn, err := dtls.Dial("udp", addr, config)
util.Check(err)
defer func() {
util.Check(dtlsConn.Close())
}()

if err := dtlsConn.HandshakeContext(ctx); err != nil {
fmt.Printf("Failed to handshake with server: %v\n", err)
return
}

fmt.Println("Connected; type 'exit' to shutdown gracefully")

// Simulate a chat session
Expand Down
7 changes: 6 additions & 1 deletion examples/dial/selfsign/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,17 @@ func main() {
// Connect to a DTLS server
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config)
dtlsConn, err := dtls.Dial("udp", addr, config)
util.Check(err)
defer func() {
util.Check(dtlsConn.Close())
}()

if err := dtlsConn.HandshakeContext(ctx); err != nil {
fmt.Printf("Failed to handshake with server: %v\n", err)
return
}

fmt.Println("Connected; type 'exit' to shutdown gracefully")

// Simulate a chat session
Expand Down
7 changes: 6 additions & 1 deletion examples/dial/verify/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,17 @@ func main() {
// Connect to a DTLS server
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config)
dtlsConn, err := dtls.Dial("udp", addr, config)
util.Check(err)
defer func() {
util.Check(dtlsConn.Close())
}()

if err := dtlsConn.HandshakeContext(ctx); err != nil {
fmt.Printf("Failed to handshake with server: %v\n", err)
return
}

fmt.Println("Connected; type 'exit' to shutdown gracefully")

// Simulate a chat session
Expand Down
16 changes: 3 additions & 13 deletions examples/listen/cid/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
package main

import (
"context"
"fmt"
"net"
"time"

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/examples/util"
Expand All @@ -18,10 +16,6 @@ func main() {
// Prepare the IP to connect to
addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444}

// Create parent context to cleanup handshaking connections on exit.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

//
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
//
Expand All @@ -32,13 +26,9 @@ func main() {
fmt.Printf("Client's hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x23}, nil
},
PSKIdentityHint: []byte("Pion DTLS Server"),
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8},
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
// Create timeout context for accepted connection.
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(ctx, 30*time.Second)
},
PSKIdentityHint: []byte("Pion DTLS Server"),
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8},
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
ConnectionIDGenerator: dtls.RandomCIDGenerator(8),
}

Expand Down
10 changes: 0 additions & 10 deletions examples/listen/psk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
package main

import (
"context"
"fmt"
"net"
"time"

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/examples/util"
Expand All @@ -18,10 +16,6 @@ func main() {
// Prepare the IP to connect to
addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444}

// Create parent context to cleanup handshaking connections on exit.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

//
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
//
Expand All @@ -35,10 +29,6 @@ func main() {
PSKIdentityHint: []byte("Pion DTLS Server"),
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8},
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
// Create timeout context for accepted connection.
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(ctx, 30*time.Second)
},
}

// Connect to a DTLS server
Expand Down
10 changes: 0 additions & 10 deletions examples/listen/selfsign/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
package main

import (
"context"
"crypto/tls"
"fmt"
"net"
"time"

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/examples/util"
Expand All @@ -24,10 +22,6 @@ func main() {
certificate, genErr := selfsign.GenerateSelfSigned()
util.Check(genErr)

// Create parent context to cleanup handshaking connections on exit.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

//
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
//
Expand All @@ -36,10 +30,6 @@ func main() {
config := &dtls.Config{
Certificates: []tls.Certificate{certificate},
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
// Create timeout context for accepted connection.
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(ctx, 30*time.Second)
},
}

// Connect to a DTLS server
Expand Down
Loading

0 comments on commit 4073924

Please sign in to comment.