Skip to content

Commit

Permalink
Stop processing responses not belonging to a request (#4)
Browse files Browse the repository at this point in the history
When running as root (or setcap) (network ipX), the listeners started
with each call to `Do` would get a copy of all the incoming responses.
This change allows requests to be sent to multiple hosts concurrently.
Only the request's associated responses are processed.

There can be a further check of ID, but when using udpX pings
(`net.ipv4.ping_group_range` set), the ID in the response is always a
completely different number and causes timeouts. It is unclear why this
number is so off.
  • Loading branch information
glinton authored Oct 22, 2019
1 parent f160bc1 commit d3c0ecf
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 60 deletions.
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190628185345-da137c7871d7 h1:rTIdg5QFRR7XCaK4LCjBiPbx8j4DQRpdYMnGn/bJUEU=
golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191021144547-ec77196f6094 h1:5O4U9trLjNpuhpynaDsqwCk+Tw6seqJz1EbqbnzHrc8=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
78 changes: 26 additions & 52 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,29 +108,20 @@ func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) {
}
}

var (
resp *Response
readErr error
)

sentAt, err := send(ctx, conn, req)
if err != nil {
return nil, err
}

resp, readErr = read(ctx, conn)
if readErr != nil {
return nil, readErr
resp, err := read(ctx, conn, req)
if err != nil {
return nil, err
}

resp.RTT = resp.rcvdAt.Sub(sentAt)
req.sentAt = sentAt
resp.Req = req

if readErr != nil {
return nil, readErr
}

return resp, nil
}

Expand Down Expand Up @@ -199,18 +190,17 @@ func (req *Request) proto() int {
return protocolIPv4ICMP
}

func read(ctx context.Context, conn *icmp.PacketConn) (*Response, error) {
func read(ctx context.Context, conn *icmp.PacketConn, req *Request) (*Response, error) {
if c4 := conn.IPv4PacketConn(); c4 != nil {
return read4(ctx, c4)
return read4(ctx, c4, req)
}
c6 := conn.IPv6PacketConn()
if c6 == nil {
return nil, errors.New("bad icmp connection type")
if c6 := conn.IPv6PacketConn(); c6 != nil {
return read6(ctx, c6, req)
}
return read6(ctx, c6)
return nil, errors.New("bad icmp connection type")
}

func read4(ctx context.Context, conn *ipv4.PacketConn) (*Response, error) {
func read4(ctx context.Context, conn *ipv4.PacketConn, req *Request) (*Response, error) {
for {
select {
case <-ctx.Done():
Expand All @@ -232,7 +222,7 @@ func read4(ctx context.Context, conn *ipv4.PacketConn) (*Response, error) {
return nil, err
}

if n <= 0 {
if cm == nil || n <= 0 || cm.Src.String() != req.Dst.String() {
continue
}

Expand All @@ -246,36 +236,28 @@ func read4(ctx context.Context, conn *ipv4.PacketConn) (*Response, error) {
continue
}

var seq uint
var id int
b, ok := m.Body.(*icmp.Echo)
if ok {
seq = uint(b.Seq)
id = b.ID
}

var ttl int
if cm != nil {
ttl = cm.TTL
if !ok || b.Seq != req.Seq {
continue
}

srcHost, _, _ := net.SplitHostPort(src.String())
dstHost, _, _ := net.SplitHostPort(conn.LocalAddr().String())
return &Response{
ID: id,
Seq: seq,
Data: bytesReceived[:n],
ID: b.ID,
Seq: uint(b.Seq),
Data: b.Data,
TotalLength: n,
Src: net.ParseIP(srcHost),
Dst: net.ParseIP(dstHost),
TTL: ttl,
TTL: cm.TTL,
rcvdAt: rcv,
}, nil
}
}
}

func read6(ctx context.Context, conn *ipv6.PacketConn) (*Response, error) {
func read6(ctx context.Context, conn *ipv6.PacketConn, req *Request) (*Response, error) {
for {
select {
case <-ctx.Done():
Expand All @@ -297,7 +279,7 @@ func read6(ctx context.Context, conn *ipv6.PacketConn) (*Response, error) {
return nil, err
}

if n <= 0 {
if cm == nil || n <= 0 || cm.Src.String() != req.Dst.String() {
continue
}

Expand All @@ -311,29 +293,21 @@ func read6(ctx context.Context, conn *ipv6.PacketConn) (*Response, error) {
continue
}

var seq uint
var id int
b, ok := m.Body.(*icmp.Echo)
if ok {
seq = uint(b.Seq)
id = b.ID
}

var ttl int
if cm != nil {
ttl = cm.HopLimit
if !ok || b.Seq != req.Seq {
continue
}

srcHost, _, _ := net.SplitHostPort(src.String())
dstHost, _, _ := net.SplitHostPort(conn.LocalAddr().String())
return &Response{
ID: id,
Seq: seq,
ID: b.ID,
Seq: uint(b.Seq),
Data: bytesReceived[:n],
TotalLength: n,
Src: net.ParseIP(srcHost),
Dst: net.ParseIP(dstHost),
TTL: ttl,
TTL: cm.HopLimit,
rcvdAt: rcv,
}, nil
}
Expand All @@ -355,7 +329,7 @@ func send(ctx context.Context, conn *icmp.PacketConn, req *Request) (time.Time,
default:
body := &icmp.Echo{
ID: req.ID,
Seq: int(req.Seq),
Seq: req.Seq,
Data: req.data(),
}

Expand All @@ -367,9 +341,9 @@ func send(ctx context.Context, conn *icmp.PacketConn, req *Request) (time.Time,

if req.proto() == protocolIPv4ICMP {
msg.Type = ipv4.ICMPTypeEcho
conn.IPv4PacketConn().SetControlMessage(ipv4.FlagTTL, true)
conn.IPv4PacketConn().SetControlMessage(ipv4.FlagTTL|ipv4.FlagSrc|ipv4.FlagDst, true)
} else {
conn.IPv6PacketConn().SetControlMessage(ipv6.FlagHopLimit, true)
conn.IPv6PacketConn().SetControlMessage(ipv6.FlagHopLimit|ipv6.FlagSrc|ipv6.FlagDst, true)
}

msgBytes, err := msg.Marshal(nil)
Expand Down
17 changes: 9 additions & 8 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ import (
func TestE2E(t *testing.T) {
c := &ping.Client{}

hostIPs := []string{"8.8.8.8", "8.8.4.4", "1.1.1.1"}
hostIPs := []string{"8.8.8.8", "8.8.4.4", "1.1.1.1", "10.10.10.10", "127.0.0.1"}
count := 3
deadline := time.Second * 5
timeout := time.Second * 2
interval := time.Second
pwg := &sync.WaitGroup{}

for _, hostIP := range hostIPs {
for xx, hostIP := range hostIPs {
pwg.Add(1)
go func(host string) {
go func(x int, host string) {
defer pwg.Done()

tick := time.NewTicker(interval)
Expand All @@ -48,10 +48,11 @@ func TestE2E(t *testing.T) {

packetsSent++
wg.Add(1)
go func(seq int) {
go func(id, seq int) {
defer wg.Done()
resp, err := c.Do(ctx, &ping.Request{
Dst: net.ParseIP(host),
ID: id + 1,
Seq: seq,
})
if err != nil {
Expand All @@ -61,7 +62,7 @@ func TestE2E(t *testing.T) {

resps <- resp
onRcv(resp)
}(packetsSent)
}(x, packetsSent)
}
}

Expand All @@ -74,14 +75,14 @@ func TestE2E(t *testing.T) {
}
onFin(packetsSent, rsps)
fmt.Println()
}(hostIP)
}(xx, hostIP)
}
pwg.Wait()
}

func onRcv(res *ping.Response) {
fmt.Printf("%d bytes from %s: icmp_seq=%d time=%v ttl=%v\n",
res.TotalLength, res.Src.String(), res.Seq, res.RTT, res.TTL)
res.TotalLength, res.Req.Dst.String(), res.Seq, res.RTT, res.TTL)
}

func onFin(packetsSent int, resps []*ping.Response) {
Expand Down Expand Up @@ -111,7 +112,7 @@ func onFin(packetsSent int, resps []*ping.Response) {
}
stdDev := time.Duration(math.Sqrt(float64(sumsquares / time.Duration(len(resps)))))

fmt.Printf("\n--- %s ping statistics ---\n", resps[0].Src.String())
fmt.Printf("\n--- %s ping statistics ---\n", resps[0].Req.Dst.String())
fmt.Printf("%d packets transmitted, %d packets received, %.2f%% packet loss\n",
packetsSent, len(resps), float64(packetsSent-len(resps))/float64(packetsSent)*100)
fmt.Printf("round-trip min/avg/max/stddev = %v/%v/%v/%v\n",
Expand Down

0 comments on commit d3c0ecf

Please sign in to comment.