From 42fff99bde5a87d99123d10b697ca1db697ad9bc Mon Sep 17 00:00:00 2001
From: Rob Murray
Date: Fri, 30 Aug 2024 09:37:20 +0100
Subject: [PATCH] Fix SetSendTimeout/SetReceiveTimeout

They were implemented using SO_SNDTIMEO/SO_RCVTIMEO on the socket
descriptor - but that doesn't work now the socket is non-blocking.

Instead, set deadlines on the file read/write.

Signed-off-by: Rob Murray
---
 nl/nl_linux.go      | 86 ++++++++++++++++++++++++++++++++++++++++-----
 nl/nl_linux_test.go | 39 ++++++++++++++++++++
 2 files changed, 114 insertions(+), 11 deletions(-)

diff --git a/nl/nl_linux.go b/nl/nl_linux.go
index 6cecc451..f05c1c09 100644
--- a/nl/nl_linux.go
+++ b/nl/nl_linux.go
@@ -4,6 +4,7 @@ package nl
 import (
 	"bytes"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"net"
 	"os"
@@ -11,6 +12,7 @@ import (
 	"sync"
 	"sync/atomic"
 	"syscall"
+	"time"
 	"unsafe"
 
 	"github.com/vishvananda/netns"
@@ -656,9 +658,11 @@ func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
 }
 
 type NetlinkSocket struct {
-	fd   int32
-	file *os.File
-	lsa  unix.SockaddrNetlink
+	fd             int32
+	file           *os.File
+	lsa            unix.SockaddrNetlink
+	sendTimeout    int64 // Nanoseconds, 0 means no timeout. Access using atomic.Load/StoreInt64
+	receiveTimeout int64 // Nanoseconds, 0 means no timeout. Access using atomic.Load/StoreInt64
 	sync.Mutex
 }
 
@@ -803,7 +807,44 @@ func (s *NetlinkSocket) GetFd() int {
 }
 
+// Send serializes request and writes it to the netlink socket. If a send
+// timeout has been set with SetSendTimeout, the write is abandoned once the
+// timeout expires and unix.EAGAIN is returned.
 func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
-	return unix.Sendto(int(s.fd), request.Serialize(), 0, &s.lsa)
+	rawConn, err := s.file.SyscallConn()
+	if err != nil {
+		return err
+	}
+	var (
+		deadline time.Time
+		innerErr error
+	)
+	// With no timeout configured the deadline stays zero, which disables it.
+	sendTimeout := atomic.LoadInt64(&s.sendTimeout)
+	if sendTimeout != 0 {
+		deadline = time.Now().Add(time.Duration(sendTimeout))
+	}
+	if err := s.file.SetWriteDeadline(deadline); err != nil {
+		return err
+	}
+	serializedReq := request.Serialize()
+	err = rawConn.Write(func(fd uintptr) (done bool) {
+		// Use the descriptor handed to the callback, as Receive does.
+		innerErr = unix.Sendto(int(fd), serializedReq, 0, &s.lsa)
+		// On EWOULDBLOCK report "not done" so the write is retried when possible.
+		return innerErr != unix.EWOULDBLOCK
+	})
+	if innerErr != nil {
+		return innerErr
+	}
+	if err != nil {
+		// The timeout was previously implemented using SO_SNDTIMEO on a blocking
+		// socket. So, continue to return EAGAIN when the timeout is reached.
+		if errors.Is(err, os.ErrDeadlineExceeded) {
+			return unix.EAGAIN
+		}
+		return err
+	}
+	return nil
 }
 
 func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) {
@@ -812,20 +853,33 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli
 		return nil, nil, err
 	}
 	var (
+		deadline time.Time
 		fromAddr *unix.SockaddrNetlink
 		rb       [RECEIVE_BUFFER_SIZE]byte
 		nr       int
 		from     unix.Sockaddr
 		innerErr error
 	)
+	receiveTimeout := atomic.LoadInt64(&s.receiveTimeout)
+	if receiveTimeout != 0 {
+		deadline = time.Now().Add(time.Duration(receiveTimeout))
+	}
+	if err := s.file.SetReadDeadline(deadline); err != nil {
+		return nil, nil, err
+	}
 	err = rawConn.Read(func(fd uintptr) (done bool) {
 		nr, from, innerErr = unix.Recvfrom(int(fd), rb[:], 0)
 		return innerErr != unix.EWOULDBLOCK
 	})
 	if innerErr != nil {
-		err = innerErr
+		return nil, nil, innerErr
 	}
 	if err != nil {
+		// The timeout was previously implemented using SO_RCVTIMEO on a blocking
+		// socket. So, continue to return EAGAIN when the timeout is reached.
+		if errors.Is(err, os.ErrDeadlineExceeded) {
+			return nil, nil, unix.EAGAIN
+		}
 		return nil, nil, err
 	}
 	fromAddr, ok := from.(*unix.SockaddrNetlink)
@@ -847,16 +901,26 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli
 
 // SetSendTimeout allows to set a send timeout on the socket
 func (s *NetlinkSocket) SetSendTimeout(timeout *unix.Timeval) error {
-	// Set a send timeout of SOCKET_SEND_TIMEOUT, this will allow the Send to periodically unblock and avoid that a routine
-	// remains stuck on a send on a closed fd
-	return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, timeout)
+	if timeout == nil {
+		return unix.EINVAL
+	}
+	// The socket is non-blocking, so SO_SNDTIMEO has no effect. Store the
+	// timeout instead; Send applies it as a write deadline.
+	duration := (time.Duration(timeout.Sec) * time.Second) + (time.Duration(timeout.Usec) * time.Microsecond)
+	atomic.StoreInt64(&s.sendTimeout, int64(duration))
+	return nil
 }
 
 // SetReceiveTimeout allows to set a receive timeout on the socket
 func (s *NetlinkSocket) SetReceiveTimeout(timeout *unix.Timeval) error {
-	// Set a read timeout of SOCKET_READ_TIMEOUT, this will allow the Read to periodically unblock and avoid that a routine
-	// remains stuck on a recvmsg on a closed fd
-	return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, timeout)
+	if timeout == nil {
+		return unix.EINVAL
+	}
+	// The socket is non-blocking, so SO_RCVTIMEO has no effect. Store the
+	// timeout instead; Receive applies it as a read deadline.
+	duration := (time.Duration(timeout.Sec) * time.Second) + (time.Duration(timeout.Usec) * time.Microsecond)
+	atomic.StoreInt64(&s.receiveTimeout, int64(duration))
+	return nil
 }
 
 // SetReceiveBufferSize allows to set a receive buffer size on the socket
diff --git a/nl/nl_linux_test.go b/nl/nl_linux_test.go
index 96de8d5e..be73cf4d 100644
--- a/nl/nl_linux_test.go
+++ b/nl/nl_linux_test.go
@@ -97,6 +97,45 @@ func TestIfSocketCloses(t *testing.T) {
 	}
 }
 
+func TestReceiveTimeout(t *testing.T) {
+	nlSock, err := Subscribe(unix.NETLINK_ROUTE, unix.RTNLGRP_NEIGH)
+	if err != nil {
+		t.Fatalf("Error creating the socket: %v", err)
+	}
+	defer nlSock.Close()
+
+	// Set a short timeout on the Receive and a much longer timeout on
+	// the test so that, even if there are neighbour changes, it's very
+	// likely there will be a timeout.
+	if err := nlSock.SetReceiveTimeout(&unix.Timeval{Sec: 0, Usec: 100000}); err != nil {
+		t.Fatalf("Error setting the receive timeout: %v", err)
+	}
+	const failAfter = 5 * time.Second
+
+	// Buffered so that the goroutine can always hand over its result and
+	// exit, even if the test has already stopped waiting for it.
+	errC := make(chan error, 1)
+	go func() {
+		for {
+			// Messages may arrive before the timeout; keep reading until
+			// Receive fails, then report the error and stop.
+			if _, _, err := nlSock.Receive(); err != nil {
+				errC <- err
+				return
+			}
+		}
+	}()
+
+	select {
+	case err := <-errC:
+		if err != unix.EAGAIN {
+			t.Fatalf("Expected EAGAIN from Receive, got: %v", err)
+		}
+	case <-time.After(failAfter):
+		t.Fatalf("No timeout received")
+	}
+}
+
 func (msg *CnMsgOp) write(b []byte) {
 	native := NativeEndian()
 	native.PutUint32(b[0:4], msg.ID.Idx)