From 64f9244a5173f5a666432ddb830d5f2400ddc98d Mon Sep 17 00:00:00 2001
From: Oliver Herms <oliver.herms@exaring.de>
Date: Mon, 22 Oct 2018 14:15:40 +0200
Subject: [PATCH] Cleanup IP version selection

---
 protocols/bgp/server/bmp_router.go           |  6 +--
 protocols/bmp/packet/per_peer_header.go      |  8 ++++
 protocols/bmp/packet/per_peer_header_test.go | 43 ++++++++++++++++++++
 3 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/protocols/bgp/server/bmp_router.go b/protocols/bgp/server/bmp_router.go
index 7f572ed2..f6956069 100644
--- a/protocols/bgp/server/bmp_router.go
+++ b/protocols/bgp/server/bmp_router.go
@@ -234,12 +234,8 @@ func (r *router) processPeerUpNotification(msg *bmppkt.PeerUpNotification) error
 	}
 
 	addrLen := net.IPv4len
-	for i := 0; i < net.IPv6len-net.IPv4len; i++ {
-		if msg.PerPeerHeader.PeerAddress[i] == 0 {
-			continue
-		}
+	if msg.PerPeerHeader.GetIPVersion() == 6 {
 		addrLen = net.IPv6len
-		break
 	}
 
 	// bnet.IPFromBytes can only fail if length of argument is not 4 or 16. However, length is ensured here.
diff --git a/protocols/bmp/packet/per_peer_header.go b/protocols/bmp/packet/per_peer_header.go
index da797d44..f7c60ed9 100644
--- a/protocols/bmp/packet/per_peer_header.go
+++ b/protocols/bmp/packet/per_peer_header.go
@@ -57,3 +57,11 @@ func decodePerPeerHeader(buf *bytes.Buffer) (*PerPeerHeader, error) {
 
 	return p, nil
 }
+
+// GetIPVersion gets the IP version of the BGP session
+func (p *PerPeerHeader) GetIPVersion() uint8 {
+	if p.PeerFlags>>7 == 1 {
+		return 6
+	}
+	return 4
+}
diff --git a/protocols/bmp/packet/per_peer_header_test.go b/protocols/bmp/packet/per_peer_header_test.go
index 679946e7..8c51ee6c 100644
--- a/protocols/bmp/packet/per_peer_header_test.go
+++ b/protocols/bmp/packet/per_peer_header_test.go
@@ -113,4 +113,47 @@ func TestDecodePerPeerHeader(t *testing.T) {
 
 		assert.Equalf(t, test.expected, p, "Test %q", test.name)
 	}
+
+}
+
+func TestGetIPVersion(t *testing.T) {
+	tests := []struct {
+		name     string
+		p        *PerPeerHeader
+		expected uint8
+	}{
+		{
+			name: "IPv4",
+			p: &PerPeerHeader{
+				PeerFlags: 0,
+			},
+			expected: 4,
+		},
+		{
+			name: "IPv4 #2",
+			p: &PerPeerHeader{
+				PeerFlags: 127,
+			},
+			expected: 4,
+		},
+		{
+			name: "IPv6",
+			p: &PerPeerHeader{
+				PeerFlags: 128,
+			},
+			expected: 6,
+		},
+		{
+			name: "IPv6 #2",
+			p: &PerPeerHeader{
+				PeerFlags: 129,
+			},
+			expected: 6,
+		},
+	}
+
+	for _, test := range tests {
+		v := test.p.GetIPVersion()
+		assert.Equal(t, test.expected, v)
+	}
 }
-- 
GitLab