forked from projectdiscovery/mapcidr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcidr.go
322 lines (282 loc) · 8 KB
/
cidr.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
// Package mapcidr implements methods to allow working with CIDRs.
package mapcidr
import (
"fmt"
"math"
"math/big"
"net"
)
// AddressRange returns the first and last addresses in the given CIDR range.
//
// The network's IP is masked before use, so an IPNet whose IP field still
// carries host bits (e.g. IP 10.0.0.5 with a /24 mask) yields the network
// address 10.0.0.0 as firstIP. net.IP.Mask also normalizes an IPv4 address
// stored in 16-byte form against a 4-byte mask (or a 4-byte address against a
// 16-byte v4 mask) to 4 bytes, so the widths of firstIP and lastIP always
// agree with each other. Both returned slices are freshly allocated and never
// alias network.IP.
//
// An error is returned when the IP and mask lengths cannot be reconciled or
// the masked address is neither 4 nor 16 bytes long.
func AddressRange(network *net.IPNet) (firstIP, lastIP net.IP, err error) {
	// Mask returns nil when the IP and mask lengths are incompatible; the
	// length check below reports that as a zero-length address.
	first := network.IP.Mask(network.Mask)
	if len(first) != net.IPv4len && len(first) != net.IPv6len {
		return nil, nil, fmt.Errorf("unsupported address length %d", len(first))
	}
	// net.IP.Mask may have used only the trailing 4 bytes of a 16-byte mask;
	// align the mask with the masked address the same way.
	mask := network.Mask[len(network.Mask)-len(first):]
	last := make(net.IP, len(first))
	for i := range first {
		// Keep the network bits from first and set every host bit to 1.
		last[i] = first[i] | ^mask[i]
	}
	return first, last, nil
}
// AddressCount parses cidr and reports how many IP addresses it spans, as
// computed by AddressCountIpnet. A malformed cidr yields 0 and the parse error.
//
// NOTE: counts of 2^64 or more (IPv6 prefixes of /64 and shorter) do not fit
// in a uint64; AddressCountIpnet's shift yields 0 for them.
func AddressCount(cidr string) (uint64, error) {
	_, network, parseErr := net.ParseCIDR(cidr)
	if parseErr != nil {
		return 0, parseErr
	}
	count := AddressCountIpnet(network)
	return count, nil
}
// AddressCountIpnet reports how many addresses network spans, i.e.
// 2^(address bits - prefix length).
//
// NOTE: the result is a uint64, so when the host part is 64 bits or wider
// (IPv6 /64 and shorter) the shift overflows and 0 is returned.
func AddressCountIpnet(network *net.IPNet) uint64 {
	ones, total := network.Mask.Size()
	hostBits := uint64(total - ones)
	return uint64(1) << hostBits
}
// SplitByNumber parses iprange and splits it into subnets that each hold as
// close as possible to number hosts, delegating to SplitIPNetByNumber.
func SplitByNumber(iprange string, number int) ([]*net.IPNet, error) {
	_, network, err := net.ParseCIDR(iprange)
	if err == nil {
		return SplitIPNetByNumber(network, number)
	}
	return nil, err
}
// SplitIPNetByNumber splits an IPNet into subnets with the closest number of
// hosts per subnet.
//
// number is the desired number of hosts per subnet and must be positive;
// otherwise an error is returned. The subnet count is the address count of
// ipnet divided by number (rounded down) and is passed to SplitIPNetIntoN,
// which returns ipnet unsplit when that count is 0 or 1.
func SplitIPNetByNumber(ipnet *net.IPNet, number int) ([]*net.IPNet, error) {
	// Guard the division below: number == 0 would panic with a divide by zero.
	if number <= 0 {
		return nil, fmt.Errorf("invalid number of hosts per subnet: %d", number)
	}
	ipsNumber := AddressCountIpnet(ipnet)
	// Integer division truncates, so any remainder is absorbed by the subnets.
	optimalSplit := int(ipsNumber / uint64(number))
	return SplitIPNetIntoN(ipnet, optimalSplit)
}
// SplitN parses iprange and attempts to split it into exactly n subnets by
// delegating to SplitIPNetIntoN.
func SplitN(iprange string, n int) ([]*net.IPNet, error) {
	_, network, err := net.ParseCIDR(iprange)
	if err == nil {
		return SplitIPNetIntoN(network, n)
	}
	return nil, err
}
// SplitIPNetIntoN attempts to split iprange into exactly n subnets, returned
// in ascending address order.
//
// When n <= 1, or iprange holds fewer than n addresses, the network is
// returned unchanged as the only element. A power-of-two n splits the range
// evenly. Otherwise the range is first split evenly into p subnets, where p is
// the largest power of two below n, and then the last n-p of those are halved.
//
// NOTE: AddressCountIpnet returns 0 for IPv6 networks of /64 or shorter (the
// count overflows uint64), so such networks are currently returned unsplit.
func SplitIPNetIntoN(iprange *net.IPNet, n int) ([]*net.IPNet, error) {
	// invalid value, or not enough addresses to give one to every subnet
	if n <= 1 || AddressCountIpnet(iprange) < uint64(n) {
		return []*net.IPNet{iprange}, nil
	}
	// Powers of two split evenly, and splitIPNet turns 2^k+1 into an even 2^k
	// split with its last subnet halved. n == 3 is excluded: splitIPNet rounds
	// 3 to the power of two 4 (the tie between 2 and 4 rounds up) and would
	// return four subnets instead of three.
	if isPowerOfTwo(n) || (n > 3 && isPowerOfTwoPlusOne(n)) {
		return splitIPNet(iprange, n)
	}
	// Largest power of two not exceeding n (strictly below n here, since n is
	// not a power of two). Comparing against n/2 avoids overflow.
	lowerPow2 := 1
	for lowerPow2 <= n/2 {
		lowerPow2 *= 2
	}
	subnets, err := splitIPNet(iprange, lowerPow2)
	if err != nil {
		return nil, err
	}
	for len(subnets) < n {
		// Halve subnets from the tail until the total reaches n: halving the
		// last `level` subnets gives len(subnets)+level subnets overall. Each
		// pair is stored upper half first so that reversing pending at the end
		// restores ascending address order.
		var pending []*net.IPNet
		level := 1
		for i := len(subnets) - 1; i >= 0; i-- {
			halves, err := divideIPNet(subnets[i])
			if err != nil {
				return nil, err
			}
			pending = append(pending, halves[1], halves[0])
			if len(subnets)-level+len(pending) == n {
				reverseIPNet(pending)
				return append(subnets[:len(subnets)-level], pending...), nil
			}
			level++
		}
		// Every subnet was halved without reaching n; repeat on the result.
		reverseIPNet(pending)
		subnets = pending
	}
	return subnets, nil
}
// divideIPNet divides ipnet into two subnets whose prefix is one bit longer:
// the half containing ipnet's first address, followed by the subnet that
// nextSubnet computes after it (the upper half).
//
// An error is returned when ipnet has no host bits left to split (an IPv4 /32
// or IPv6 /128) or its mask is non-canonical (IPMask.Size reports 0, 0).
// Without this check, net.CIDRMask would return nil for the over-long prefix
// and the resulting subnets would have nil IPs and masks.
func divideIPNet(ipnet *net.IPNet) ([]*net.IPNet, error) {
	ones, bits := ipnet.Mask.Size()
	if ones >= bits {
		return nil, fmt.Errorf("cannot divide %s: no host bits left", ipnet)
	}
	halfPrefix := ones + 1
	lower, err := currentSubnet(ipnet, halfPrefix)
	if err != nil {
		return nil, err
	}
	upper, err := nextSubnet(lower, halfPrefix)
	if err != nil {
		return nil, err
	}
	return []*net.IPNet{lower, upper}, nil
}
// splitIPNet splits ipnet into roughly n equally sized subnets.
//
// n is rounded to the nearest power of two by closestPowerOfTwo (ties round
// up, e.g. 3 becomes 4). The network is then cut into that many consecutive
// subnets of equal size, starting at its first address. If that yields fewer
// than n subnets, the final one is halved, adding exactly one more.
func splitIPNet(ipnet *net.IPNet, n int) ([]*net.IPNet, error) {
	result := make([]*net.IPNet, 0, n)
	pieces := int(closestPowerOfTwo(uint32(n)))
	ones, _ := ipnet.Mask.Size()
	// pieces is a power of two, so log2 gives the number of extra prefix bits.
	prefix := ones + int(math.Log2(float64(pieces)))

	first, err := currentSubnet(ipnet, prefix)
	if err != nil {
		return nil, err
	}
	result = append(result, first)

	cursor := first
	for remaining := pieces - 1; remaining > 0; remaining-- {
		cursor, err = nextSubnet(cursor, prefix)
		if err != nil {
			return nil, err
		}
		result = append(result, cursor)
	}

	if len(result) < n {
		tail := len(result) - 1
		halves, err := divideIPNet(result[tail])
		if err != nil {
			return nil, err
		}
		result = append(result[:tail], halves...)
	}
	return result, nil
}
// func split(iprange string, n int) ([]*net.IPNet, error) {
// _, ipnet, _ := net.ParseCIDR(iprange)
// return splitIPNet(ipnet, n)
// }
// nextPowerOfTwo returns the smallest power of two that is >= v. It smears
// the highest set bit of v-1 into every lower bit position and then adds one.
// Because the arithmetic is uint32, v == 0 and any v above 1<<31 wrap
// around and yield 0.
func nextPowerOfTwo(v uint32) uint32 {
	x := v - 1
	for shift := uint(1); shift <= 16; shift *= 2 {
		x |= x >> shift
	}
	return x + 1
}
// closestPowerOfTwo returns the power of two nearest to v. When v lies exactly
// halfway between two powers of two (3, 6, 12, ...) the larger one is chosen,
// because the lower bound only wins on a strictly smaller distance.
func closestPowerOfTwo(v uint32) uint32 {
	upper := nextPowerOfTwo(v)
	lower := upper / 2
	if v-lower < upper-v {
		return lower
	}
	return upper
}
// currentSubnet returns the subnet with the given prefix length that contains
// the first address AddressRange reports for network. The mask width follows
// the byte length of that address (32 or 128 bits); if prefixLen exceeds it,
// net.CIDRMask returns nil.
func currentSubnet(network *net.IPNet, prefixLen int) (*net.IPNet, error) {
	start, _, err := AddressRange(network)
	if err != nil {
		return nil, err
	}
	width := 8 * len(start) //nolint
	subnetMask := net.CIDRMask(prefixLen, width)
	return &net.IPNet{
		IP:   start.Mask(subnetMask),
		Mask: subnetMask,
	}, nil
}
// nextSubnet returns the subnet of length prefixLen that immediately follows
// the prefixLen-subnet containing network's last address.
//
// NOTE(review): stepping past the end of the address space returns whatever
// inc produces for the all-ones address (presumably a wrap to zero); no error
// is reported. splitIPNet and divideIPNet only step within the range being
// split.
func nextSubnet(network *net.IPNet, prefixLen int) (*net.IPNet, error) {
	_, currentLast, err := AddressRange(network)
	if err != nil {
		return nil, err
	}
	mask := net.CIDRMask(prefixLen, 8*len(currentLast)) //nolint
	// The prefixLen-subnet that contains network's last address.
	containing := &net.IPNet{IP: currentLast.Mask(mask), Mask: mask}
	_, last, err := AddressRange(containing)
	if err != nil {
		return nil, err
	}
	// One past its last address is the first address of the following subnet.
	following := inc(last)
	return &net.IPNet{IP: following.Mask(mask), Mask: mask}, nil
}
// isPowerOfTwoPlusOne reports whether x is one more than a power of two
// (2, 3, 5, 9, 17, ...). It defers to isPowerOfTwo so both predicates share a
// single definition of "power of two".
func isPowerOfTwoPlusOne(x int) bool {
	predecessor := x - 1
	return isPowerOfTwo(predecessor)
}
// isPowerOfTwo reports whether x is a positive power of two (1, 2, 4, 8, ...).
// Non-positive inputs return false. This includes the most negative int,
// whose two's-complement bit pattern has a single set bit and would otherwise
// pass the x&(x-1) test.
func isPowerOfTwo(x int) bool {
	return x > 0 && x&(x-1) == 0
}
// reverseIPNet reverses the order of the elements of ipnets in place by
// swapping each element in the first half with its mirror in the second half.
func reverseIPNet(ipnets []*net.IPNet) {
	last := len(ipnets) - 1
	for i := 0; i < len(ipnets)/2; i++ {
		ipnets[i], ipnets[last-i] = ipnets[last-i], ipnets[i]
	}
}
// IPAddresses returns every IP address in cidr as a string, in the order
// produced by IPAddressesIPnet. When cidr cannot be parsed it returns an empty,
// non-nil slice together with the parse error.
func IPAddresses(cidr string) ([]string, error) {
	_, network, parseErr := net.ParseCIDR(cidr)
	if parseErr != nil {
		return []string{}, parseErr
	}
	return IPAddressesIPnet(network), nil
}
// IPAddressesAsStream parses cidr and returns the channel produced by
// IpAddresses, which yields every IP address in the range as a string. A parse
// failure returns a nil channel and the error.
func IPAddressesAsStream(cidr string) (chan string, error) {
	_, network, parseErr := net.ParseCIDR(cidr)
	if parseErr != nil {
		return nil, parseErr
	}
	stream := IpAddresses(network)
	return stream, nil
}
// IPAddressesIPnet collects every IP address in ipnet into a slice by draining
// the stream returned by IpAddresses until it is closed.
func IPAddressesIPnet(ipnet *net.IPNet) (ips []string) {
	stream := IpAddresses(ipnet)
	for {
		addr, open := <-stream
		if !open {
			return ips
		}
		ips = append(ips, addr)
	}
}
// IpAddresses streams every IP address in ipnet, as strings, over the
// returned channel. It covers the range's first address through its last,
// inclusive, with the bounds taken from ipNetToRange and successors from
// GetNextIP. The channel is closed after the last address has been sent.
//
// The channel is unbuffered and fed by a background goroutine. If the caller
// stops receiving before the channel is closed, that goroutine stays blocked
// forever, so callers should drain the channel fully.
func IpAddresses(ipnet *net.IPNet) (ips chan string) {
	ips = make(chan string)
	go func() {
		defer close(ips)
		bounds := ipNetToRange(*ipnet)
		first, last := *bounds.First, *bounds.Last
		current := first
		for !current.Equal(last) {
			ips <- current.String()
			current = GetNextIP(current)
		}
		// The loop stops before emitting the final address; send it here.
		ips <- last.String()
	}()
	return ips
}
// IPToInteger converts an IP address to its integer representation.
// It supports both IPv4 as well as IPv6 addresses.
//
// The second return value is the address width in bits, taken from the byte
// length of ip: 32 for a 4-byte address and 128 for a 16-byte one. An IPv4
// address in 16-byte form (as net.ParseIP returns it) is therefore reported as
// a 128-bit IPv4-mapped value; call ip.To4() first to get the 32-bit form.
// Any other length returns an error and a nil integer.
func IPToInteger(ip net.IP) (*big.Int, int, error) {
	var bits int
	switch len(ip) {
	case net.IPv4len:
		bits = 32 //nolint
	case net.IPv6len:
		bits = 128 //nolint
	default:
		return nil, 0, fmt.Errorf("unsupported address length %d", len(ip))
	}
	return new(big.Int).SetBytes([]byte(ip)), bits, nil
}
// IntegerToIP converts an Integer IP address to net.IP format.
//
// The result is bits/8 bytes long (4 for bits=32, 16 for bits=128). The
// big-endian magnitude of ipInt is right-aligned in it and zero-padded on the
// left. A value that needs more than bits/8 bytes keeps only its low-order
// bytes, so it wraps instead of causing an index-out-of-range panic. The sign
// of ipInt is ignored because big.Int.Bytes returns the absolute value.
func IntegerToIP(ipInt *big.Int, bits int) net.IP {
	ret := make(net.IP, bits/8) //nolint
	ipBytes := ipInt.Bytes()
	if len(ipBytes) > len(ret) {
		ipBytes = ipBytes[len(ipBytes)-len(ret):]
	}
	copy(ret[len(ret)-len(ipBytes):], ipBytes)
	return ret
}