Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement unix socket diagnostics #946

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,13 @@ type Socket struct {
UID uint32
INode uint32
}

// UnixSocket represents a netlink unix socket.
type UnixSocket struct {
Type uint8
Family uint8
State uint8
pad uint8
INode uint32
Cookie [2]uint32
}
148 changes: 144 additions & 4 deletions socket_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import (
)

const (
sizeofSocketID = 0x30
sizeofSocketRequest = sizeofSocketID + 0x8
sizeofSocket = sizeofSocketID + 0x18
sizeofSocketID = 0x30
sizeofSocketRequest = sizeofSocketID + 0x8
sizeofSocket = sizeofSocketID + 0x18
sizeofUnixSocketRequest = 0x18 // 24 byte
sizeofUnixSocket = 0x10 // 16 byte
)

type socketRequest struct {
Expand Down Expand Up @@ -67,6 +69,32 @@ func (r *socketRequest) Serialize() []byte {

func (r *socketRequest) Len() int { return sizeofSocketRequest }

// According to linux/include/uapi/linux/unix_diag.h
type unixSocketRequest struct {
Family uint8
Protocol uint8
pad uint16
States uint32
INode uint32
Show uint32
Cookie [2]uint32
}

func (r *unixSocketRequest) Serialize() []byte {
b := writeBuffer{Bytes: make([]byte, sizeofUnixSocketRequest)}
b.Write(r.Family)
b.Write(r.Protocol)
native.PutUint16(b.Next(2), r.pad)
native.PutUint32(b.Next(4), r.States)
native.PutUint32(b.Next(4), r.INode)
native.PutUint32(b.Next(4), r.Show)
native.PutUint32(b.Next(4), r.Cookie[0])
native.PutUint32(b.Next(4), r.Cookie[1])
return b.Bytes
}

func (r *unixSocketRequest) Len() int { return sizeofUnixSocketRequest }

type readBuffer struct {
Bytes []byte
pos int
Expand Down Expand Up @@ -115,6 +143,21 @@ func (s *Socket) deserialize(b []byte) error {
return nil
}

func (u *UnixSocket) deserialize(b []byte) error {
if len(b) < sizeofUnixSocket {
return fmt.Errorf("unix diag data short read (%d); want %d", len(b), sizeofUnixSocket)
}
rb := readBuffer{Bytes: b}
u.Type = rb.Read()
u.Family = rb.Read()
u.State = rb.Read()
u.pad = rb.Read()
u.INode = native.Uint32(rb.Next(4))
u.Cookie[0] = native.Uint32(rb.Next(4))
u.Cookie[1] = native.Uint32(rb.Next(4))
return nil
}

// SocketGet returns the Socket identified by its local and remote addresses.
func SocketGet(local, remote net.Addr) (*Socket, error) {
localTCP, ok := local.(*net.TCPAddr)
Expand Down Expand Up @@ -157,7 +200,7 @@ func SocketGet(local, remote net.Addr) (*Socket, error) {
return nil, err
}
if from.Pid != nl.PidKernel {
return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
return nil, fmt.Errorf("wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
}
if len(msgs) == 0 {
return nil, errors.New("no message nor error from netlink")
Expand Down Expand Up @@ -305,6 +348,78 @@ func SocketDiagUDP(family uint8) ([]*Socket, error) {
return result, nil
}

// UnixSocketDiagInfo requests UNIX_DIAG_INFO for unix sockets and return with extension info.
func UnixSocketDiagInfo() ([]*UnixDiagInfoResp, error) {
// Construct the request
var extensions uint8
extensions = 1 << UNIX_DIAG_NAME
extensions |= 1 << UNIX_DIAG_PEER
extensions |= 1 << UNIX_DIAG_RQLEN
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&unixSocketRequest{
Family: unix.AF_UNIX,
States: ^uint32(0), // all states
Show: uint32(extensions),
})

var result []*UnixDiagInfoResp
err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error {
sockInfo := &UnixSocket{}
if err := sockInfo.deserialize(m.Data); err != nil {
return err
}

// Diagnosis also delivers sockets with AF_INET family, filter those
if sockInfo.Family != unix.AF_UNIX {
return nil
}

attrs, err := nl.ParseRouteAttr(m.Data[sizeofUnixSocket:])
if err != nil {
return err
}

res, err := attrsToUnixDiagInfoResp(attrs, sockInfo)
if err != nil {
return err
}
result = append(result, res)
return nil
})
if err != nil {
return nil, err
}
return result, nil
}

// UnixSocketDiag requests UNIX_DIAG_INFO for unix sockets.
func UnixSocketDiag() ([]*UnixSocket, error) {
// Construct the request
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&unixSocketRequest{
Family: unix.AF_UNIX,
States: ^uint32(0), // all states
})

var result []*UnixSocket
err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error {
sockInfo := &UnixSocket{}
if err := sockInfo.deserialize(m.Data); err != nil {
return err
}

// Diagnosis also delivers sockets with AF_INET family, filter those
if sockInfo.Family == unix.AF_UNIX {
result = append(result, sockInfo)
}
return nil
})
if err != nil {
return nil, err
}
return result, nil
}

// socketDiagExecutor requests diagnoses info from the NETLINK_INET_DIAG socket for the specified request.
func socketDiagExecutor(req *nl.NetlinkRequest, receiver func(syscall.NetlinkMessage) error) error {
s, err := nl.Subscribe(unix.NETLINK_INET_DIAG)
Expand Down Expand Up @@ -381,3 +496,28 @@ func attrsToInetDiagUDPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Sock

return info, nil
}

func attrsToUnixDiagInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *UnixSocket) (*UnixDiagInfoResp, error) {
info := &UnixDiagInfoResp{
DiagMsg: sockInfo,
}
for _, a := range attrs {
switch a.Attr.Type {
case UNIX_DIAG_NAME:
name := string(a.Value[:a.Attr.Len])
info.Name = &name
case UNIX_DIAG_PEER:
peer := native.Uint32(a.Value)
info.Peer = &peer
case UNIX_DIAG_RQLEN:
info.Queue = &QueueInfo{
RQueue: native.Uint32(a.Value[:4]),
WQueue: native.Uint32(a.Value[4:]),
}
// default:
// fmt.Println("unknown unix attribute type", a.Attr.Type, "with data", a.Value)
}
}

return info, nil
}
16 changes: 16 additions & 0 deletions socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package netlink

import (
"fmt"
"log"
"net"
"os/user"
Expand Down Expand Up @@ -91,3 +92,18 @@ func TestSocketDiagUDPnfo(t *testing.T) {
}
}
}

func TestUnixSocketDiagInfo(t *testing.T) {
want := syscall.AF_UNIX
result, err := UnixSocketDiagInfo()
if err != nil {
t.Fatal(err)
}

for i, r := range result {
fmt.Println(r.DiagMsg)
if got := r.DiagMsg.Family; got != uint8(want) {
t.Fatalf("%d: protocol family = %v, want %v", i, got, want)
}
}
}
27 changes: 27 additions & 0 deletions unix_diag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package netlink

// According to linux/include/uapi/linux/unix_diag.h
const (
UNIX_DIAG_NAME = iota
UNIX_DIAG_VFS
UNIX_DIAG_PEER
UNIX_DIAG_ICONS
UNIX_DIAG_RQLEN
UNIX_DIAG_MEMINFO
UNIX_DIAG_SHUTDOWN
UNIX_DIAG_UID
UNIX_DIAG_MAX
)

type UnixDiagInfoResp struct {
DiagMsg *UnixSocket
Name *string
Peer *uint32
Queue *QueueInfo
Shutdown *uint8
}

type QueueInfo struct {
RQueue uint32
WQueue uint32
}
Loading