diff --git a/net/ip.go b/net/ip.go index a9223fe400aca86e68bcf4fb3a21051478105d59..428a431b6bb6a8fad4e24fe66a9d8aad6482fe52 100644 --- a/net/ip.go +++ b/net/ip.go @@ -57,6 +57,14 @@ func (ip *IP) Higher() uint64 { return ip.higher } +func (ip *IP) copy() *IP { + return &IP{ + higher: ip.higher, + lower: ip.lower, + isLegacy: ip.isLegacy, + } +} + // IPv4 returns a new `IP` representing an IPv4 address func IPv4(val uint32) IP { return IP{ @@ -285,3 +293,19 @@ func (ip *IP) bitAtPositionIPv6(pos uint8) bool { return (ip.lower & (1 << (128 - pos))) != 0 } + +// Next gets the next ip address +func (ip *IP) Next() *IP { + newIP := ip.copy() + if ip.isLegacy { + newIP.lower++ + return newIP + } + + newIP.lower++ + if newIP.lower == 0 { + newIP.higher++ + } + + return newIP +} diff --git a/net/ip_test.go b/net/ip_test.go index fb666ca23c8af8e78a01d9f2dde4f23f0c835f0b..279909f7053415367cf48c4c05208aa8db73ca03 100644 --- a/net/ip_test.go +++ b/net/ip_test.go @@ -589,3 +589,31 @@ func TestSizeBytes(t *testing.T) { assert.Equal(t, test.expected, test.input.SizeBytes(), test.name) } } + +func TestNext(t *testing.T) { + tests := []struct { + name string + input *IP + expected *IP + }{ + { + name: "Test #1", + input: IPv4FromOctets(10, 0, 0, 1).Dedup(), + expected: IPv4FromOctets(10, 0, 0, 2).Dedup(), + }, + { + name: "Test #2", + input: IPv6FromBlocks(10, 20, 30, 40, 50, 60, 70, 80).Dedup(), + expected: IPv6FromBlocks(10, 20, 30, 40, 50, 60, 70, 81).Dedup(), + }, + { + name: "Test #3", + input: IPv6FromBlocks(10, 20, 30, 40, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF).Dedup(), + expected: IPv6FromBlocks(10, 20, 30, 41, 0, 0, 0, 0).Dedup(), + }, + } + + for _, test := range tests { + assert.Equal(t, test.expected, test.input.Next(), test.name) + } +} diff --git a/net/prefix.go b/net/prefix.go index 8905a9b3cf64c798fb018df5028b1224ffa02860..f0a2f91f2094618d0ffd5a48741fa0fb2360ed9a 100644 --- a/net/prefix.go +++ b/net/prefix.go @@ -265,3 +265,36 @@ func checkLastNBitsUint32(x uint32, n uint8) bool { func checkLastNBitsUint64(x uint64, n uint8) bool { return x<<(64-n) == 0 } + +// BaseAddr gets the base address of the prefix +func (p *Prefix) BaseAddr() *IP { + if p.addr.isLegacy { + return p.baseAddr4() + } + + return p.baseAddr6() +} + +func (p *Prefix) baseAddr4() *IP { + addr := p.addr.copy() + + addr.lower = addr.lower >> (32 - p.pfxlen) + addr.lower = addr.lower << (32 - p.pfxlen) + + return addr +} + +func (p *Prefix) baseAddr6() *IP { + addr := p.addr.copy() + + if p.pfxlen <= 64 { + addr.lower = 0 + addr.higher = addr.higher >> (64 - p.pfxlen) + addr.higher = addr.higher << (64 - p.pfxlen) + } else { + addr.lower = addr.lower >> (128 - p.pfxlen) + addr.lower = addr.lower << (128 - p.pfxlen) + } + + return addr +} diff --git a/net/prefix_test.go b/net/prefix_test.go index 64bed942034b4e114e98be1a4e497ce19d8c072c..14eda7592b0d66f7967196169f07859490197528 100644 --- a/net/prefix_test.go +++ b/net/prefix_test.go @@ -657,3 +657,41 @@ func TestValid(t *testing.T) { assert.Equal(t, test.expected, p.Valid(), test.name) } } + +func TestBaseAddr(t *testing.T) { + tests := []struct { + name string + input *Prefix + expected *IP + }{ + { + name: "Test #1", + input: NewPfx(IPv4FromOctets(10, 1, 1, 0), 23).Dedup(), + expected: IPv4FromOctets(10, 1, 0, 0).Dedup(), + }, + { + name: "Test #2", + input: NewPfx(IPv4FromOctets(10, 1, 1, 2), 24).Dedup(), + expected: IPv4FromOctets(10, 1, 1, 0).Dedup(), + }, + { + name: "Test #3", + input: NewPfx(IPv6FromBlocks(10, 10, 20, 20, 1, 0, 0, 1), 64).Dedup(), + expected: IPv6FromBlocks(10, 10, 20, 20, 0, 0, 0, 0).Dedup(), + }, + { + name: "Test #4", + input: NewPfx(IPv6FromBlocks(10, 10, 20, 20, 1, 0, 0, 1), 48).Dedup(), + expected: IPv6FromBlocks(10, 10, 20, 0, 0, 0, 0, 0).Dedup(), + }, + { + name: "Test #5", + input: NewPfx(IPv6FromBlocks(10, 10, 20, 20, 1, 0, 5, 1), 126).Dedup(), + expected: IPv6FromBlocks(10, 10, 20, 20, 1, 0, 5, 0).Dedup(), + }, + } + + for _, test := range tests { + assert.Equal(t, test.expected, test.input.BaseAddr(), test.name) + } +}