diff --git a/route_linux.go b/route_linux.go index 6cd329f8..5bebfd4e 100644 --- a/route_linux.go +++ b/route_linux.go @@ -1148,7 +1148,7 @@ func (h *Handle) RouteListFiltered(family int, filter *Route, filterMask uint64) var res []Route for _, m := range msgs { msg := nl.DeserializeRtMsg(m) - if msg.Family != uint8(family) { + if family != FAMILY_ALL && msg.Family != uint8(family) { // Ignore routes not matching requested family continue } diff --git a/route_test.go b/route_test.go index 166c357e..c73205f2 100644 --- a/route_test.go +++ b/route_test.go @@ -876,6 +876,79 @@ func TestRouteFilterAllTables(t *testing.T) { } } +func TestRouteFilterByFamily(t *testing.T) { + tearDown := setUpNetlinkTest(t) + defer tearDown() + + const table int = 999 + + // get loopback interface + link, err := LinkByName("lo") + if err != nil { + t.Fatal(err) + } + // bring the interface up + if err = LinkSetUp(link); err != nil { + t.Fatal(err) + } + + // add a IPv4 gateway route + dst4 := &net.IPNet{ + IP: net.IPv4(2, 2, 0, 0), + Mask: net.CIDRMask(24, 32), + } + route4 := Route{LinkIndex: link.Attrs().Index, Dst: dst4, Table: table} + if err := RouteAdd(&route4); err != nil { + t.Fatal(err) + } + + // add a IPv6 gateway route + dst6 := &net.IPNet{ + IP: net.ParseIP("2001:db9::0"), + Mask: net.CIDRMask(64, 128), + } + route6 := Route{LinkIndex: link.Attrs().Index, Dst: dst6, Table: table} + if err := RouteAdd(&route6); err != nil { + t.Fatal(err) + } + + // Get routes for both families + routes_all, err := RouteListFiltered(FAMILY_ALL, &Route{Table: table}, RT_FILTER_TABLE) + if err != nil { + t.Fatal(err) + } + if len(routes_all) != 2 { + t.Fatal("Filtering by FAMILY_ALL doesn't find two routes") + } + + // Get IPv4 route + routes_v4, err := RouteListFiltered(FAMILY_V4, &Route{Table: table}, RT_FILTER_TABLE) + if err != nil { + t.Fatal(err) + } + if len(routes_v4) != 1 { + t.Fatal("Filtering by FAMILY_V4 doesn't find one route") + } + + // Get IPv6 route + routes_v6, err := RouteListFiltered(FAMILY_V6, &Route{Table: table}, RT_FILTER_TABLE) + if err != nil { + t.Fatal(err) + } + if len(routes_v6) != 1 { + t.Fatal("Filtering by FAMILY_V6 doesn't find one route") + } + + // Get non-existent routes + routes_non_existent, err := RouteListFiltered(99, &Route{Table: table}, RT_FILTER_TABLE) + if err != nil { + t.Fatal(err) + } + if len(routes_non_existent) != 0 { + t.Fatal("Filtering by non-existent family find some route") + } +} + func tableIDIn(ids []int, id int) bool { for _, v := range ids { if v == id {