diff --git a/net/ip.go b/net/ip.go index b671dbb54ee2b80c83481de1546c9c79f806baab..34721c723e563643a7806bc43994f1b57280bed0 100644 --- a/net/ip.go +++ b/net/ip.go @@ -2,9 +2,11 @@ package net import ( "fmt" + "math" "net" api "github.com/bio-routing/bio-rd/net/api" + bmath "github.com/bio-routing/bio-rd/util/math" ) // IP represents an IPv4 or IPv6 address @@ -309,3 +311,32 @@ func (ip *IP) Next() *IP { return newIP } + +// MaskLastNBits masks the last n bits of an IP address +func (ip *IP) MaskLastNBits(n uint8) *IP { + ip = ip.copy() + + if ip.isLegacy { + ip.maskLastNBitsIPv4(n) + return ip + } + + ip.maskLastNBitsIPv6(n) + return ip +} + +func (ip *IP) maskLastNBitsIPv4(n uint8) { + mask := uint64((math.MaxUint64 << (n))) + ip.lower = ip.lower & mask +} + +func (ip *IP) maskLastNBitsIPv6(n uint8) { + maskBitsLow := uint8(bmath.Min(int(n), 64)) + maskBitsHigh := uint8(bmath.Max(int(n)-64, 0)) + + maskLow := uint64((math.MaxUint64 << (maskBitsLow))) + maskHigh := uint64((math.MaxUint64 << (maskBitsHigh))) + + ip.lower = ip.lower & maskLow + ip.higher = ip.higher & maskHigh +} diff --git a/net/ip_test.go b/net/ip_test.go index 28b3bbfb6a71bbda2bba59f2a09bfceb30930e4d..5a646b7261d7b30f27dc24d451c9a4174855cd20 100644 --- a/net/ip_test.go +++ b/net/ip_test.go @@ -617,3 +617,60 @@ func TestNext(t *testing.T) { assert.Equal(t, test.expected, test.input.Next(), test.name) } } + +func TestMaskLastNBits(t *testing.T) { + tests := []struct { + name string + input *IP + maskBits uint8 + expected *IP + }{ + { + name: "Test #1", + input: IPv4FromOctets(10, 1, 1, 1).Dedup(), + maskBits: 8, + expected: IPv4FromOctets(10, 1, 1, 0).Dedup(), + }, + { + name: "Test #2", + input: IPv4FromOctets(185, 65, 241, 123).Dedup(), + maskBits: 9, + expected: IPv4FromOctets(185, 65, 240, 0).Dedup(), + }, + { + name: "Test #3", + input: IPv4FromOctets(185, 65, 241, 123).Dedup(), + maskBits: 32, + expected: IPv4FromOctets(0, 0, 0, 0).Dedup(), + }, + { + name: "Test #4", + input: IPv6FromBlocks(0x2001, 0xaaaa, 0x1234, 0x2222, 0x1111, 0x3333, 0xbbbb, 0xacab).Dedup(), + maskBits: 16, + expected: IPv6FromBlocks(0x2001, 0xaaaa, 0x1234, 0x2222, 0x1111, 0x3333, 0xbbbb, 0x0000).Dedup(), + }, + { + name: "Test #5", + input: IPv6FromBlocks(0x2001, 0xaaaa, 0x1234, 0x2222, 0x1111, 0x3333, 0xbbbb, 0xacab).Dedup(), + maskBits: 64, + expected: IPv6FromBlocks(0x2001, 0xaaaa, 0x1234, 0x2222, 0x0000, 0x0000, 0x0000, 0x0000).Dedup(), + }, + { + name: "Test #6", + input: IPv6FromBlocks(0x2001, 0xaaaa, 0x1234, 0x2222, 0x1111, 0x3333, 0xbbbb, 0xacab).Dedup(), + maskBits: 80, + expected: IPv6FromBlocks(0x2001, 0xaaaa, 0x1234, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000).Dedup(), + }, + { + name: "Test #7", + input: IPv6FromBlocks(0x2001, 0xaaaa, 0x1234, 0x2222, 0x1111, 0x3333, 0xbbbb, 0xacab).Dedup(), + maskBits: 128, + expected: IPv6FromBlocks(0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000).Dedup(), + }, + } + + for _, test := range tests { + res := test.input.MaskLastNBits(test.maskBits) + assert.Equal(t, test.expected, res, test.name) + } +} diff --git a/util/math/max.go b/util/math/max.go new file mode 100644 index 0000000000000000000000000000000000000000..7ab6f2adb9aa5f05ad42357a2e4be23fa0302439 --- /dev/null +++ b/util/math/max.go @@ -0,0 +1,9 @@ +package math + +// Max returns the maximum of a and b +func Max(a, b int) int { + if a > b { + return a + } + return b +}