Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SetSendTimeout/SetReceiveTimeout #1012

Merged
merged 1 commit into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 12 additions & 29 deletions handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ import (
"sync/atomic"
"testing"
"time"
"unsafe"

"github.com/vishvananda/netlink/nl"
"github.com/vishvananda/netns"
"golang.org/x/sys/unix"
)

func TestHandleCreateClose(t *testing.T) {
Expand Down Expand Up @@ -122,13 +120,22 @@ func TestHandleTimeout(t *testing.T) {
defer h.Close()

for _, sh := range h.sockets {
verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 0, Usec: 0})
verifySockTimeVal(t, sh.Socket, time.Duration(0))
}

h.SetSocketTimeout(2*time.Second + 8*time.Millisecond)
const timeout = 2*time.Second + 8*time.Millisecond
h.SetSocketTimeout(timeout)

for _, sh := range h.sockets {
verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 2, Usec: 8000})
verifySockTimeVal(t, sh.Socket, timeout)
}
}

func verifySockTimeVal(t *testing.T, socket *nl.NetlinkSocket, expTimeout time.Duration) {
t.Helper()
send, receive := socket.GetTimeouts()
if send != expTimeout || receive != expTimeout {
t.Fatalf("Expected timeout: %v, got Send: %v, Receive: %v", expTimeout, send, receive)
}
}

Expand Down Expand Up @@ -157,30 +164,6 @@ func TestHandleReceiveBuffer(t *testing.T) {
}
}

func verifySockTimeVal(t *testing.T, fd int, tv unix.Timeval) {
var (
tr unix.Timeval
v = uint32(0x10)
)
_, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0)
if errno != 0 {
t.Fatal(errno)
}

if tr.Sec != tv.Sec || tr.Usec != tv.Usec {
t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv)
}

_, _, errno = unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0)
if errno != 0 {
t.Fatal(errno)
}

if tr.Sec != tv.Sec || tr.Usec != tv.Usec {
t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv)
}
}

var (
iter = 10
numThread = uint32(4)
Expand Down
73 changes: 62 additions & 11 deletions nl/nl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ package nl
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"net"
"os"
"runtime"
"sync"
"sync/atomic"
"syscall"
"time"
"unsafe"

"github.com/vishvananda/netns"
Expand Down Expand Up @@ -656,9 +658,11 @@ func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
}

type NetlinkSocket struct {
fd int32
file *os.File
lsa unix.SockaddrNetlink
fd int32
file *os.File
lsa unix.SockaddrNetlink
sendTimeout int64 // Access using atomic.Load/StoreInt64
receiveTimeout int64 // Access using atomic.Load/StoreInt64
sync.Mutex
}

Expand Down Expand Up @@ -802,8 +806,44 @@ func (s *NetlinkSocket) GetFd() int {
return int(s.fd)
}

func (s *NetlinkSocket) GetTimeouts() (send, receive time.Duration) {
return time.Duration(atomic.LoadInt64(&s.sendTimeout)),
time.Duration(atomic.LoadInt64(&s.receiveTimeout))
}

func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
return unix.Sendto(int(s.fd), request.Serialize(), 0, &s.lsa)
rawConn, err := s.file.SyscallConn()
if err != nil {
return err
}
var (
deadline time.Time
innerErr error
)
sendTimeout := atomic.LoadInt64(&s.sendTimeout)
if sendTimeout != 0 {
deadline = time.Now().Add(time.Duration(sendTimeout))
}
if err := s.file.SetWriteDeadline(deadline); err != nil {
return err
}
serializedReq := request.Serialize()
err = rawConn.Write(func(fd uintptr) (done bool) {
innerErr = unix.Sendto(int(s.fd), serializedReq, 0, &s.lsa)
return innerErr != unix.EWOULDBLOCK
})
if innerErr != nil {
return innerErr
}
if err != nil {
// The timeout was previously implemented using SO_SNDTIMEO on a blocking
// socket. So, continue to return EAGAIN when the timeout is reached.
if errors.Is(err, os.ErrDeadlineExceeded) {
return unix.EAGAIN
}
return err
}
return nil
}

func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) {
Expand All @@ -812,20 +852,33 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli
return nil, nil, err
}
var (
deadline time.Time
fromAddr *unix.SockaddrNetlink
rb [RECEIVE_BUFFER_SIZE]byte
nr int
from unix.Sockaddr
innerErr error
)
receiveTimeout := atomic.LoadInt64(&s.receiveTimeout)
if receiveTimeout != 0 {
deadline = time.Now().Add(time.Duration(receiveTimeout))
}
if err := s.file.SetReadDeadline(deadline); err != nil {
return nil, nil, err
}
err = rawConn.Read(func(fd uintptr) (done bool) {
nr, from, innerErr = unix.Recvfrom(int(fd), rb[:], 0)
return innerErr != unix.EWOULDBLOCK
})
if innerErr != nil {
err = innerErr
return nil, nil, innerErr
}
if err != nil {
// The timeout was previously implemented using SO_RCVTIMEO on a blocking
// socket. So, continue to return EAGAIN when the timeout is reached.
if errors.Is(err, os.ErrDeadlineExceeded) {
return nil, nil, unix.EAGAIN
}
return nil, nil, err
}
fromAddr, ok := from.(*unix.SockaddrNetlink)
Expand All @@ -847,16 +900,14 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli

// SetSendTimeout allows to set a send timeout on the socket
func (s *NetlinkSocket) SetSendTimeout(timeout *unix.Timeval) error {
// Set a send timeout of SOCKET_SEND_TIMEOUT, this will allow the Send to periodically unblock and avoid that a routine
// remains stuck on a send on a closed fd
return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, timeout)
atomic.StoreInt64(&s.sendTimeout, timeout.Nano())
return nil
}

// SetReceiveTimeout allows to set a receive timeout on the socket
func (s *NetlinkSocket) SetReceiveTimeout(timeout *unix.Timeval) error {
// Set a read timeout of SOCKET_READ_TIMEOUT, this will allow the Read to periodically unblock and avoid that a routine
// remains stuck on a recvmsg on a closed fd
return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, timeout)
atomic.StoreInt64(&s.receiveTimeout, timeout.Nano())
return nil
}

// SetReceiveBufferSize allows to set a receive buffer size on the socket
Expand Down
63 changes: 63 additions & 0 deletions nl/nl_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,69 @@ func TestIfSocketCloses(t *testing.T) {
}
}

func TestReceiveTimeout(t *testing.T) {
nlSock, err := getNetlinkSocket(unix.NETLINK_ROUTE)
if err != nil {
t.Fatalf("Error creating the socket: %v", err)
}
// Even if the test fails because the timeout doesn't work, closing the
// socket at the end of the test should result in an EAGAIN (as long as
// TestIfSocketCloses completed, otherwise this test will leak the
// goroutines running the Receive).
defer nlSock.Close()
const failAfter = time.Second

tests := []struct {
name string
timeout time.Duration
}{
{
name: "1us timeout", // The smallest value accepted by Handle.SetSocketTimeout
timeout: time.Microsecond,
},
{
name: "100ms timeout",
timeout: 100 * time.Millisecond,
},
{
name: "500ms timeout",
timeout: 500 * time.Millisecond,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
timeout := unix.NsecToTimeval(int64(tc.timeout))
nlSock.SetReceiveTimeout(&timeout)

doneC := make(chan time.Duration)
errC := make(chan error)
go func() {
start := time.Now()
_, _, err := nlSock.Receive()
dur := time.Since(start)
if err != unix.EAGAIN {
errC <- err
return
}
doneC <- dur
}()

failTimerC := time.After(failAfter)
select {
case dur := <-doneC:
if dur < tc.timeout || dur > (tc.timeout+(100*time.Millisecond)) {
t.Fatalf("Expected timeout %v got %v", tc.timeout, dur)
}
case err := <-errC:
t.Fatalf("Expected EAGAIN, but got: %v", err)
case <-failTimerC:
t.Fatalf("No timeout received")
}
})
}
}

func (msg *CnMsgOp) write(b []byte) {
native := NativeEndian()
native.PutUint32(b[0:4], msg.ID.Idx)
Expand Down
Loading