From 6f81931c62d93ba1859b4ff0db92a305bc4c4159 Mon Sep 17 00:00:00 2001 From: Sven Rebhan Date: Wed, 15 Nov 2023 15:50:31 +0100 Subject: [PATCH] Implement UDP socket diagnostics Signed-off-by: Sven Rebhan --- go.mod | 1 + go.sum | 17 ++++++ inet_diag.go | 5 ++ socket_linux.go | 142 +++++++++++++++++++++++++++++++++++++++--------- socket_test.go | 74 +++++++++++-------------- tcp.go | 8 +++ tcp_linux.go | 15 +++++ 7 files changed, 194 insertions(+), 68 deletions(-) diff --git a/go.mod b/go.mod index dbf5f205..92928c51 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/vishvananda/netlink go 1.12 require ( + github.com/stretchr/testify v1.8.4 github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae golang.org/x/sys v0.10.0 ) diff --git a/go.sum b/go.sum index e56e1f08..7e4b8b20 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,22 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/inet_diag.go b/inet_diag.go index bee391a8..a483ee1a 100644 --- a/inet_diag.go +++ b/inet_diag.go @@ -29,3 +29,8 @@ type InetDiagTCPInfoResp struct { TCPInfo *TCPInfo TCPBBRInfo *TCPBBRInfo } + +type InetDiagUDPInfoResp struct { + InetDiagMsg *Socket + Memory *MemInfo +} diff --git a/socket_linux.go b/socket_linux.go index b881fe49..f39c2a6a 100644 --- a/socket_linux.go +++ b/socket_linux.go @@ -174,8 +174,18 @@ func SocketGet(local, remote net.Addr) (*Socket, error) { // SocketDiagTCPInfo requests INET_DIAG_INFO for TCP protocol for specified family type and return with extension TCP info. func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { + // Construct the request + req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req.AddData(&socketRequest{ + Family: family, + Protocol: unix.IPPROTO_TCP, + Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)), + States: uint32(0xfff), // all states + }) + + // Do the query and parse the result var result []*InetDiagTCPInfoResp - err := socketDiagTCPExecutor(family, func(m syscall.NetlinkMessage) error { + err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { sockInfo := &Socket{} if err := sockInfo.deserialize(m.Data); err != nil { return err @@ -201,8 +211,18 @@ func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { // SocketDiagTCP requests INET_DIAG_INFO for TCP protocol for specified family type and return related socket. func SocketDiagTCP(family uint8) ([]*Socket, error) { + // Construct the request + req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req.AddData(&socketRequest{ + Family: family, + Protocol: unix.IPPROTO_TCP, + Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)), + States: uint32(0xfff), // all states + }) + + // Do the query and parse the result var result []*Socket - err := socketDiagTCPExecutor(family, func(m syscall.NetlinkMessage) error { + err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { sockInfo := &Socket{} if err := sockInfo.deserialize(m.Data); err != nil { return err @@ -216,21 +236,82 @@ func SocketDiagTCP(family uint8) ([]*Socket, error) { return result, nil } -// socketDiagTCPExecutor requests INET_DIAG_INFO for TCP protocol for specified family type. -func socketDiagTCPExecutor(family uint8, receiver func(syscall.NetlinkMessage) error) error { - s, err := nl.Subscribe(unix.NETLINK_INET_DIAG) +// SocketDiagUDPInfo requests INET_DIAG_INFO for UDP protocol for specified family type and return with extension info. +func SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) { + // Construct the request + var extensions uint8 + extensions = 1 << (INET_DIAG_VEGASINFO - 1) + extensions |= 1 << (INET_DIAG_INFO - 1) + extensions |= 1 << (INET_DIAG_MEMINFO - 1) + + req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req.AddData(&socketRequest{ + Family: family, + Protocol: unix.IPPROTO_UDP, + Ext: extensions, + States: uint32(0xfff), // all states + }) + + // Do the query and parse the result + var result []*InetDiagUDPInfoResp + err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { + sockInfo := &Socket{} + if err := sockInfo.deserialize(m.Data); err != nil { + return err + } + attrs, err := nl.ParseRouteAttr(m.Data[sizeofSocket:]) + if err != nil { + return err + } + + res, err := attrsToInetDiagUDPInfoResp(attrs, sockInfo) + if err != nil { + return err + } + + result = append(result, res) + return nil + }) if err != nil { - return err + return nil, err } - defer s.Close() + return result, nil +} +// SocketDiagUDP requests INET_DIAG_INFO for UDP protocol for specified family type and return related socket. +func SocketDiagUDP(family uint8) ([]*Socket, error) { + // Construct the request req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req.AddData(&socketRequest{ Family: family, - Protocol: unix.IPPROTO_TCP, + Protocol: unix.IPPROTO_UDP, Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)), - States: uint32(0xfff), // All TCP states + States: uint32(0xfff), // all states }) + + // Do the query and parse the result + var result []*Socket + err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { + sockInfo := &Socket{} + if err := sockInfo.deserialize(m.Data); err != nil { + return err + } + result = append(result, sockInfo) + return nil + }) + if err != nil { + return nil, err + } + return result, nil +} + +// socketDiagExecutor requests diagnoses info from the NETLINK_INET_DIAG socket for the specified request. +func socketDiagExecutor(req *nl.NetlinkRequest, receiver func(syscall.NetlinkMessage) error) error { + s, err := nl.Subscribe(unix.NETLINK_INET_DIAG) + if err != nil { + return err + } + defer s.Close() s.Send(req) loop: @@ -240,7 +321,7 @@ loop: return err } if from.Pid != nl.PidKernel { - return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel) + return fmt.Errorf("wrong sender portid %d, expected %d", from.Pid, nl.PidKernel) } if len(msgs) == 0 { return errors.New("no message nor error from netlink") @@ -263,29 +344,40 @@ loop: } func attrsToInetDiagTCPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Socket) (*InetDiagTCPInfoResp, error) { - var tcpInfo *TCPInfo - var tcpBBRInfo *TCPBBRInfo + info := &InetDiagTCPInfoResp{ + InetDiagMsg: sockInfo, + } for _, a := range attrs { - if a.Attr.Type == INET_DIAG_INFO { - tcpInfo = &TCPInfo{} - if err := tcpInfo.deserialize(a.Value); err != nil { + switch a.Attr.Type { + case INET_DIAG_INFO: + info.TCPInfo = &TCPInfo{} + if err := info.TCPInfo.deserialize(a.Value); err != nil { + return nil, err + } + case INET_DIAG_BBRINFO: + info.TCPBBRInfo = &TCPBBRInfo{} + if err := info.TCPBBRInfo.deserialize(a.Value); err != nil { return nil, err } - continue } + } + + return info, nil +} - if a.Attr.Type == INET_DIAG_BBRINFO { - tcpBBRInfo = &TCPBBRInfo{} - if err := tcpBBRInfo.deserialize(a.Value); err != nil { +func attrsToInetDiagUDPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Socket) (*InetDiagUDPInfoResp, error) { + info := &InetDiagUDPInfoResp{ + InetDiagMsg: sockInfo, + } + for _, a := range attrs { + switch a.Attr.Type { + case INET_DIAG_MEMINFO: + info.Memory = &MemInfo{} + if err := info.Memory.deserialize(a.Value); err != nil { return nil, err } - continue } } - return &InetDiagTCPInfoResp{ - InetDiagMsg: sockInfo, - TCPInfo: tcpInfo, - TCPBBRInfo: tcpBBRInfo, - }, nil + return info, nil } diff --git a/socket_test.go b/socket_test.go index f21520f1..bf514d24 100644 --- a/socket_test.go +++ b/socket_test.go @@ -1,77 +1,65 @@ +//go:build linux // +build linux package netlink import ( - "log" "net" "os/user" "strconv" "syscall" "testing" + + "github.com/stretchr/testify/require" ) func TestSocketGet(t *testing.T) { defer setUpNetlinkTestWithLoopback(t)() addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) + l, err := net.ListenTCP("tcp", addr) - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) defer l.Close() conn, err := net.Dial(l.Addr().Network(), l.Addr().String()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer conn.Close() localAddr := conn.LocalAddr().(*net.TCPAddr) remoteAddr := conn.RemoteAddr().(*net.TCPAddr) socket, err := SocketGet(localAddr, remoteAddr) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + + require.EqualValues(t, localAddr.IP.To16(), socket.ID.Source.To16(), "local ip") + require.EqualValues(t, remoteAddr.IP.To16(), socket.ID.Destination.To16(), "remote ip") + require.EqualValues(t, localAddr.Port, socket.ID.SourcePort, "local port") + require.EqualValues(t, remoteAddr.Port, socket.ID.DestinationPort, "remote port") - if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) { - t.Fatalf("local ip = %v, want %v", got, want) - } - if got, want := socket.ID.Destination, remoteAddr.IP; !got.Equal(want) { - t.Fatalf("remote ip = %v, want %v", got, want) - } - if got, want := int(socket.ID.SourcePort), localAddr.Port; got != want { - t.Fatalf("local port = %d, want %d", got, want) - } - if got, want := int(socket.ID.DestinationPort), remoteAddr.Port; got != want { - t.Fatalf("remote port = %d, want %d", got, want) - } u, err := user.Current() - if err != nil { - t.Fatal(err) - } - if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want { - t.Fatalf("UID = %s, want %s", got, want) - } + require.NoError(t, err) + require.EqualValues(t, u.Uid, strconv.Itoa(int(socket.UID)), "UID") } func TestSocketDiagTCPInfo(t *testing.T) { - Family4 := uint8(syscall.AF_INET) - Family6 := uint8(syscall.AF_INET6) - families := []uint8{Family4, Family6} - for _, wantFamily := range families { - res, err := SocketDiagTCPInfo(wantFamily) - if err != nil { - t.Fatal(err) + for _, expected := range []uint8{syscall.AF_INET, syscall.AF_INET6} { + result, err := SocketDiagTCPInfo(expected) + require.NoError(t, err) + + for _, i := range result { + require.Equal(t, expected, i.InetDiagMsg.Family) } - for _, i := range res { - gotFamily := i.InetDiagMsg.Family - if gotFamily != wantFamily { - t.Fatalf("Socket family = %d, want %d", gotFamily, wantFamily) - } + } +} + +func TestSocketDiagUDPnfo(t *testing.T) { + for _, expected := range []uint8{syscall.AF_INET, syscall.AF_INET6} { + result, err := SocketDiagUDPInfo(expected) + require.NoError(t, err) + + for _, i := range result { + require.Equal(t, expected, i.InetDiagMsg.Family) } } } diff --git a/tcp.go b/tcp.go index 23ca014d..43f80a0f 100644 --- a/tcp.go +++ b/tcp.go @@ -82,3 +82,11 @@ type TCPBBRInfo struct { BBRPacingGain uint32 BBRCwndGain uint32 } + +// According to https://man7.org/linux/man-pages/man7/sock_diag.7.html +type MemInfo struct { + RMem uint32 + WMem uint32 + FMem uint32 + TMem uint32 +} diff --git a/tcp_linux.go b/tcp_linux.go index 29385873..e98036da 100644 --- a/tcp_linux.go +++ b/tcp_linux.go @@ -8,6 +8,7 @@ import ( const ( tcpBBRInfoLen = 20 + memInfoLen = 16 ) func checkDeserErr(err error) error { @@ -351,3 +352,17 @@ func (t *TCPBBRInfo) deserialize(b []byte) error { return nil } + +func (m *MemInfo) deserialize(b []byte) error { + if len(b) != memInfoLen { + return errors.New("Invalid length") + } + + rb := bytes.NewBuffer(b) + m.RMem = native.Uint32(rb.Next(4)) + m.WMem = native.Uint32(rb.Next(4)) + m.FMem = native.Uint32(rb.Next(4)) + m.TMem = native.Uint32(rb.Next(4)) + + return nil +}