diff --git a/protocols/bgp/server/bmp_router.go b/protocols/bgp/server/bmp_router.go index 7f572ed2f20319efbe24e09222f4d794f5931851..f6956069d33ad0f981227a0a9c026cb7e2dcba44 100644 --- a/protocols/bgp/server/bmp_router.go +++ b/protocols/bgp/server/bmp_router.go @@ -234,12 +234,8 @@ func (r *router) processPeerUpNotification(msg *bmppkt.PeerUpNotification) error } addrLen := net.IPv4len - for i := 0; i < net.IPv6len-net.IPv4len; i++ { - if msg.PerPeerHeader.PeerAddress[i] == 0 { - continue - } + if msg.PerPeerHeader.GetIPVersion() == 6 { addrLen = net.IPv6len - break } // bnet.IPFromBytes can only fail if length of argument is not 4 or 16. However, length is ensured here. diff --git a/protocols/bmp/packet/per_peer_header.go b/protocols/bmp/packet/per_peer_header.go index da797d44237177a3c6ef4972ab2d7cd55d9d665a..f7c60ed97f98068cf2bb573b348407584ef2d3e4 100644 --- a/protocols/bmp/packet/per_peer_header.go +++ b/protocols/bmp/packet/per_peer_header.go @@ -57,3 +57,11 @@ func decodePerPeerHeader(buf *bytes.Buffer) (*PerPeerHeader, error) { return p, nil } + +// GetIPVersion gets the IP version of the BGP session +func (p *PerPeerHeader) GetIPVersion() uint8 { + if p.PeerFlags>>7 == 1 { + return 6 + } + return 4 +} diff --git a/protocols/bmp/packet/per_peer_header_test.go b/protocols/bmp/packet/per_peer_header_test.go index 679946e76c19d05696bd26fb06c9b03dfb6a6847..8c51ee6c9b286faf927893d20ab093fe792ce0ce 100644 --- a/protocols/bmp/packet/per_peer_header_test.go +++ b/protocols/bmp/packet/per_peer_header_test.go @@ -113,4 +113,47 @@ func TestDecodePerPeerHeader(t *testing.T) { assert.Equalf(t, test.expected, p, "Test %q", test.name) } + +} + +func TestGetIPVersion(t *testing.T) { + tests := []struct { + name string + p *PerPeerHeader + expected uint8 + }{ + { + name: "IPv4", + p: &PerPeerHeader{ + PeerFlags: 0, + }, + expected: 4, + }, + { + name: "IPv4 #2", + p: &PerPeerHeader{ + PeerFlags: 127, + }, + expected: 4, + }, + { + name: "IPv6", + p: &PerPeerHeader{ + PeerFlags: 128, + }, + expected: 6, + }, + { + name: "IPv6 #2", + p: &PerPeerHeader{ + PeerFlags: 129, + }, + expected: 6, + }, + } + + for _, test := range tests { + v := test.p.GetIPVersion() + assert.Equal(t, test.expected, v) + } }