From 2206c93b90f70da316783475dd4dd409d6760587 Mon Sep 17 00:00:00 2001
From: Daniel Czerwonk <daniel@dan-nrw.de>
Date: Thu, 5 Jul 2018 22:03:13 +0200
Subject: [PATCH] moved ribs for AFI/SAFI in own type

---
 protocols/bgp/server/family_routing.go        | 204 ++++++++++++++++++
 ...blished_test.go => family_routing_test.go} |   4 +-
 protocols/bgp/server/fsm.go                   |  15 +-
 protocols/bgp/server/fsm_established.go       | 165 +++-----------
 protocols/bgp/server/fsm_test.go              |   6 +-
 protocols/bgp/server/helper.go                |  14 ++
 protocols/bgp/server/update_sender.go         |  38 +++-
 protocols/bgp/server/update_sender_test.go    |  15 +-
 protocols/bgp/server/withdraw.go              |   7 +-
 protocols/bgp/server/withdraw_test.go         |   4 +-
 10 files changed, 302 insertions(+), 170 deletions(-)
 create mode 100644 protocols/bgp/server/family_routing.go
 rename protocols/bgp/server/{fsm_established_test.go => family_routing_test.go} (95%)
 create mode 100644 protocols/bgp/server/helper.go

diff --git a/protocols/bgp/server/family_routing.go b/protocols/bgp/server/family_routing.go
new file mode 100644
index 00000000..a5115480
--- /dev/null
+++ b/protocols/bgp/server/family_routing.go
@@ -0,0 +1,204 @@
+package server
+
+import (
+	"time"
+
+	bnet "github.com/bio-routing/bio-rd/net"
+	"github.com/bio-routing/bio-rd/protocols/bgp/packet"
+	"github.com/bio-routing/bio-rd/protocols/bgp/types"
+	"github.com/bio-routing/bio-rd/route"
+	"github.com/bio-routing/bio-rd/routingtable"
+	"github.com/bio-routing/bio-rd/routingtable/adjRIBIn"
+	"github.com/bio-routing/bio-rd/routingtable/adjRIBOut"
+	"github.com/bio-routing/bio-rd/routingtable/filter"
+	"github.com/bio-routing/bio-rd/routingtable/locRIB"
+)
+
+// familyRouting holds RIBs and the UpdateSender of an peer for an AFI/SAFI combination
+type familyRouting struct {
+	afi  uint16
+	safi uint8
+	fsm  *FSM
+
+	adjRIBIn  routingtable.RouteTableClient
+	adjRIBOut routingtable.RouteTableClient
+	rib       *locRIB.LocRIB
+
+	importFilter *filter.Filter
+	exportFilter *filter.Filter
+
+	updateSender *UpdateSender
+
+	initialized bool
+}
+
+func newFamilyRouting(afi uint16, safi uint8, rib *locRIB.LocRIB, fsm *FSM) *familyRouting {
+	return &familyRouting{
+		afi:  afi,
+		safi: safi,
+		rib:  rib,
+		fsm:  fsm,
+	}
+}
+
+func (f *familyRouting) init(n *routingtable.Neighbor) {
+	contributingASNs := f.rib.GetContributingASNs()
+
+	f.adjRIBIn = adjRIBIn.New(f.fsm.peer.importFilter, contributingASNs, f.fsm.peer.routerID, f.fsm.peer.clusterID)
+	contributingASNs.Add(f.fsm.peer.localASN)
+	f.adjRIBIn.Register(f.rib)
+
+	f.adjRIBOut = adjRIBOut.New(n, f.fsm.peer.exportFilter)
+	clientOptions := routingtable.ClientOptions{
+		BestOnly: true,
+	}
+	if f.fsm.options.AddPathRX {
+		clientOptions = f.fsm.peer.addPathSend
+	}
+
+	f.updateSender = newUpdateSender(f.fsm, f.afi, f.safi)
+	f.updateSender.Start(time.Millisecond * 5)
+
+	f.adjRIBOut.Register(f.updateSender)
+	f.rib.RegisterWithOptions(f.adjRIBOut, clientOptions)
+}
+
+func (f *familyRouting) dispose() {
+	if !f.initialized {
+		return
+	}
+
+	f.rib.GetContributingASNs().Remove(f.fsm.peer.localASN)
+	f.adjRIBIn.Unregister(f.rib)
+	f.rib.Unregister(f.adjRIBOut)
+	f.adjRIBOut.Unregister(f.updateSender)
+	f.updateSender.Destroy()
+
+	f.adjRIBIn = nil
+	f.adjRIBOut = nil
+
+	f.initialized = false
+}
+
+func (f *familyRouting) processUpdate(u *packet.BGPUpdate) {
+	if f.afi == packet.IPv4AFI && f.safi == packet.UnicastSAFI {
+		f.withdraws(u)
+		f.updates(u)
+	}
+
+	if f.fsm.options.SupportsMultiProtocol {
+		f.multiProtocolUpdates(u)
+	}
+}
+
+func (f *familyRouting) withdraws(u *packet.BGPUpdate) {
+	for r := u.WithdrawnRoutes; r != nil; r = r.Next {
+		pfx := bnet.NewPfx(bnet.IPv4(r.IP), r.Pfxlen)
+		f.adjRIBIn.RemovePath(pfx, nil)
+	}
+}
+
+func (f *familyRouting) updates(u *packet.BGPUpdate) {
+	for r := u.NLRI; r != nil; r = r.Next {
+		pfx := bnet.NewPfx(bnet.IPv4(r.IP), r.Pfxlen)
+
+		path := f.newRoutePath()
+		f.processAttributes(u.PathAttributes, path)
+
+		f.adjRIBIn.AddPath(pfx, path)
+	}
+}
+
+func (f *familyRouting) multiProtocolUpdates(u *packet.BGPUpdate) {
+	if !f.fsm.options.SupportsMultiProtocol {
+		return
+	}
+
+	path := f.newRoutePath()
+	f.processAttributes(u.PathAttributes, path)
+
+	for pa := u.PathAttributes; pa != nil; pa = pa.Next {
+		switch pa.TypeCode {
+		case packet.MultiProtocolReachNLRICode:
+			f.multiProtocolUpdate(path, pa.Value.(packet.MultiProtocolReachNLRI))
+		case packet.MultiProtocolUnreachNLRICode:
+			f.multiProtocolWithdraw(path, pa.Value.(packet.MultiProtocolUnreachNLRI))
+		}
+	}
+}
+
+func (f *familyRouting) newRoutePath() *route.Path {
+	return &route.Path{
+		Type: route.BGPPathType,
+		BGPPath: &route.BGPPath{
+			Source: f.fsm.peer.addr,
+			EBGP:   f.fsm.peer.localASN != f.fsm.peer.peerASN,
+		},
+	}
+}
+
+func (f *familyRouting) multiProtocolUpdate(path *route.Path, nlri packet.MultiProtocolReachNLRI) {
+	path.BGPPath.NextHop = nlri.NextHop
+
+	for _, pfx := range nlri.Prefixes {
+		f.adjRIBIn.AddPath(pfx, path)
+	}
+}
+
+func (f *familyRouting) multiProtocolWithdraw(path *route.Path, nlri packet.MultiProtocolUnreachNLRI) {
+	for _, pfx := range nlri.Prefixes {
+		f.adjRIBIn.RemovePath(pfx, path)
+	}
+}
+
+func (f *familyRouting) processAttributes(attrs *packet.PathAttribute, path *route.Path) {
+	for pa := attrs; pa != nil; pa = pa.Next {
+		switch pa.TypeCode {
+		case packet.OriginAttr:
+			path.BGPPath.Origin = pa.Value.(uint8)
+		case packet.LocalPrefAttr:
+			path.BGPPath.LocalPref = pa.Value.(uint32)
+		case packet.MEDAttr:
+			path.BGPPath.MED = pa.Value.(uint32)
+		case packet.NextHopAttr:
+			path.BGPPath.NextHop = pa.Value.(bnet.IP)
+		case packet.ASPathAttr:
+			path.BGPPath.ASPath = pa.Value.(types.ASPath)
+			path.BGPPath.ASPathLen = path.BGPPath.ASPath.Length()
+		case packet.AggregatorAttr:
+			aggr := pa.Value.(types.Aggregator)
+			path.BGPPath.Aggregator = &aggr
+		case packet.AtomicAggrAttr:
+			path.BGPPath.AtomicAggregate = true
+		case packet.CommunitiesAttr:
+			path.BGPPath.Communities = pa.Value.([]uint32)
+		case packet.LargeCommunitiesAttr:
+			path.BGPPath.LargeCommunities = pa.Value.([]types.LargeCommunity)
+		case packet.OriginatorIDAttr:
+			path.BGPPath.OriginatorID = pa.Value.(uint32)
+		case packet.ClusterListAttr:
+			path.BGPPath.ClusterList = pa.Value.([]uint32)
+		default:
+			unknownAttr := f.processUnknownAttribute(pa)
+			if unknownAttr != nil {
+				path.BGPPath.UnknownAttributes = append(path.BGPPath.UnknownAttributes, *unknownAttr)
+			}
+		}
+	}
+}
+
+func (f *familyRouting) processUnknownAttribute(attr *packet.PathAttribute) *types.UnknownPathAttribute {
+	if !attr.Transitive {
+		return nil
+	}
+
+	u := &types.UnknownPathAttribute{
+		Transitive: true,
+		Optional:   attr.Optional,
+		Partial:    attr.Partial,
+		TypeCode:   attr.TypeCode,
+		Value:      attr.Value.([]byte),
+	}
+
+	return u
+}
diff --git a/protocols/bgp/server/fsm_established_test.go b/protocols/bgp/server/family_routing_test.go
similarity index 95%
rename from protocols/bgp/server/fsm_established_test.go
rename to protocols/bgp/server/family_routing_test.go
index 61103d99..102c2e05 100644
--- a/protocols/bgp/server/fsm_established_test.go
+++ b/protocols/bgp/server/family_routing_test.go
@@ -43,12 +43,12 @@ func TestProcessAttributes(t *testing.T) {
 		Next: unknown1,
 	}
 
-	e := &establishedState{}
+	f := &familyRouting{}
 
 	p := &route.Path{
 		BGPPath: &route.BGPPath{},
 	}
-	e.processAttributes(asPath, p)
+	f.processAttributes(asPath, p)
 
 	expectedCodes := []uint8{200, 100}
 	expectedValues := [][]byte{[]byte{5, 6}, []byte{1, 2, 3, 4}}
diff --git a/protocols/bgp/server/fsm.go b/protocols/bgp/server/fsm.go
index 991f1b60..05a83eec 100644
--- a/protocols/bgp/server/fsm.go
+++ b/protocols/bgp/server/fsm.go
@@ -8,8 +8,6 @@ import (
 
 	"github.com/bio-routing/bio-rd/protocols/bgp/packet"
 	"github.com/bio-routing/bio-rd/protocols/bgp/types"
-	"github.com/bio-routing/bio-rd/routingtable"
-	"github.com/bio-routing/bio-rd/routingtable/locRIB"
 	log "github.com/sirupsen/logrus"
 )
 
@@ -60,10 +58,8 @@ type FSM struct {
 	local net.IP
 
 	ribsInitialized bool
-	adjRIBIn        routingtable.RouteTableClient
-	adjRIBOut       routingtable.RouteTableClient
-	rib             *locRIB.LocRIB
-	updateSender    *UpdateSender
+	ipv4Unicast     *familyRouting
+	ipv6Unicast     *familyRouting
 
 	neighborID uint32
 	state      state
@@ -89,7 +85,7 @@ func NewActiveFSM2(peer *peer) *FSM {
 }
 
 func newFSM2(peer *peer) *FSM {
-	return &FSM{
+	f := &FSM{
 		connectRetryTime: time.Minute,
 		peer:             peer,
 		eventCh:          make(chan int),
@@ -99,9 +95,12 @@ func newFSM2(peer *peer) *FSM {
 		msgRecvCh:        make(chan []byte),
 		msgRecvFailCh:    make(chan error),
 		stopMsgRecvCh:    make(chan struct{}),
-		rib:              peer.rib,
 		options:          &types.Options{},
 	}
+	f.ipv4Unicast = newFamilyRouting(packet.IPv4AFI, packet.UnicastSAFI, peer.rib, f)
+	f.ipv6Unicast = newFamilyRouting(packet.IPv6AFI, packet.UnicastSAFI, peer.rib, f)
+
+	return f
 }
 
 func (fsm *FSM) start() {
diff --git a/protocols/bgp/server/fsm_established.go b/protocols/bgp/server/fsm_established.go
index 2baf46ab..78f7f194 100644
--- a/protocols/bgp/server/fsm_established.go
+++ b/protocols/bgp/server/fsm_established.go
@@ -4,15 +4,11 @@ import (
 	"bytes"
 	"fmt"
 	"net"
-	"time"
 
 	bnet "github.com/bio-routing/bio-rd/net"
 	"github.com/bio-routing/bio-rd/protocols/bgp/packet"
-	"github.com/bio-routing/bio-rd/protocols/bgp/types"
 	"github.com/bio-routing/bio-rd/route"
 	"github.com/bio-routing/bio-rd/routingtable"
-	"github.com/bio-routing/bio-rd/routingtable/adjRIBIn"
-	"github.com/bio-routing/bio-rd/routingtable/adjRIBOut"
 )
 
 type establishedState struct {
@@ -57,12 +53,6 @@ func (s establishedState) run() (state, string) {
 }
 
 func (s *establishedState) init() error {
-	contributingASNs := s.fsm.rib.GetContributingASNs()
-
-	s.fsm.adjRIBIn = adjRIBIn.New(s.fsm.peer.importFilter, contributingASNs, s.fsm.peer.routerID, s.fsm.peer.clusterID)
-	contributingASNs.Add(s.fsm.peer.localASN)
-	s.fsm.adjRIBIn.Register(s.fsm.rib)
-
 	host, _, err := net.SplitHostPort(s.fsm.con.LocalAddr().String())
 	if err != nil {
 		return fmt.Errorf("Unable to get local address: %v", err)
@@ -88,35 +78,15 @@ func (s *establishedState) init() error {
 		ClusterID:            s.fsm.peer.clusterID,
 	}
 
-	s.fsm.adjRIBOut = adjRIBOut.New(n, s.fsm.peer.exportFilter)
-	clientOptions := routingtable.ClientOptions{
-		BestOnly: true,
-	}
-	if s.fsm.options.AddPathRX {
-		clientOptions = s.fsm.peer.addPathSend
-	}
-
-	s.fsm.updateSender = newUpdateSender(s.fsm)
-	s.fsm.updateSender.Start(time.Millisecond * 5)
-
-	s.fsm.adjRIBOut.Register(s.fsm.updateSender)
-	s.fsm.rib.RegisterWithOptions(s.fsm.adjRIBOut, clientOptions)
-
+	s.fsm.ipv4Unicast.init(n)
+	s.fsm.ipv6Unicast.init(n)
 	s.fsm.ribsInitialized = true
 	return nil
 }
 
 func (s *establishedState) uninit() {
-	s.fsm.rib.GetContributingASNs().Remove(s.fsm.peer.localASN)
-	s.fsm.adjRIBIn.Unregister(s.fsm.rib)
-	s.fsm.rib.Unregister(s.fsm.adjRIBOut)
-	s.fsm.adjRIBOut.Unregister(s.fsm.updateSender)
-	s.fsm.updateSender.Destroy()
-
-	s.fsm.adjRIBIn = nil
-	s.fsm.adjRIBOut = nil
-
-	s.fsm.ribsInitialized = false
+	s.fsm.ipv4Unicast.dispose()
+	s.fsm.ipv6Unicast.dispose()
 }
 
 func (s *establishedState) manualStop() (state, string) {
@@ -205,123 +175,44 @@ func (s *establishedState) update(msg *packet.BGPMessage) (state, string) {
 	}
 
 	u := msg.Body.(*packet.BGPUpdate)
-	s.withdraws(u)
-	s.updates(u)
-	s.multiProtocolUpdates(u)
+	afi, safi := s.addressFamilyForUpdate(u)
 
-	return newEstablishedState(s.fsm), s.fsm.reason
-}
-
-func (s *establishedState) withdraws(u *packet.BGPUpdate) {
-	for r := u.WithdrawnRoutes; r != nil; r = r.Next {
-		pfx := bnet.NewPfx(bnet.IPv4(r.IP), r.Pfxlen)
-		s.fsm.adjRIBIn.RemovePath(pfx, nil)
+	if safi != packet.UnicastSAFI {
+		// only unicast support, so other SAFIs are ignored
+		return newEstablishedState(s.fsm), s.fsm.reason
 	}
-}
-
-func (s *establishedState) updates(u *packet.BGPUpdate) {
-	for r := u.NLRI; r != nil; r = r.Next {
-		pfx := bnet.NewPfx(bnet.IPv4(r.IP), r.Pfxlen)
-
-		path := s.newRoutePath()
-		s.processAttributes(u.PathAttributes, path)
-
-		s.fsm.adjRIBIn.AddPath(pfx, path)
-	}
-}
 
-func (s *establishedState) multiProtocolUpdates(u *packet.BGPUpdate) {
-	if !s.fsm.options.SupportsMultiProtocol {
-		return
+	switch afi {
+	case packet.IPv4AFI:
+		s.fsm.ipv4Unicast.processUpdate(u)
+	case packet.IPv6AFI:
+		s.fsm.ipv6Unicast.processUpdate(u)
 	}
 
-	path := s.newRoutePath()
-	s.processAttributes(u.PathAttributes, path)
-
-	for pa := u.PathAttributes; pa != nil; pa = pa.Next {
-		switch pa.TypeCode {
-		case packet.MultiProtocolReachNLRICode:
-			s.multiProtocolUpdate(path, pa.Value.(packet.MultiProtocolReachNLRI))
-		case packet.MultiProtocolUnreachNLRICode:
-			s.multiProtocolWithdraw(path, pa.Value.(packet.MultiProtocolUnreachNLRI))
-		}
-	}
+	return newEstablishedState(s.fsm), s.fsm.reason
 }
 
-func (s *establishedState) newRoutePath() *route.Path {
-	return &route.Path{
-		Type: route.BGPPathType,
-		BGPPath: &route.BGPPath{
-			Source: s.fsm.peer.addr,
-			EBGP:   s.fsm.peer.localASN != s.fsm.peer.peerASN,
-		},
+func (s *establishedState) addressFamilyForUpdate(u *packet.BGPUpdate) (afi uint16, safi uint8) {
+	if !s.fsm.options.SupportsMultiProtocol || u.NLRI != nil || u.WithdrawnRoutes != nil {
+		return packet.IPv4AFI, packet.UnicastSAFI
 	}
-}
 
-func (s *establishedState) multiProtocolUpdate(path *route.Path, nlri packet.MultiProtocolReachNLRI) {
-	path.BGPPath.NextHop = nlri.NextHop
+	cur := u.PathAttributes
+	for cur != nil {
+		cur = cur.Next
 
-	for _, pfx := range nlri.Prefixes {
-		s.fsm.adjRIBIn.AddPath(pfx, path)
-	}
-}
-
-func (s *establishedState) multiProtocolWithdraw(path *route.Path, nlri packet.MultiProtocolUnreachNLRI) {
-	for _, pfx := range nlri.Prefixes {
-		s.fsm.adjRIBIn.RemovePath(pfx, path)
-	}
-}
-
-func (s *establishedState) processAttributes(attrs *packet.PathAttribute, path *route.Path) {
-	for pa := attrs; pa != nil; pa = pa.Next {
-		switch pa.TypeCode {
-		case packet.OriginAttr:
-			path.BGPPath.Origin = pa.Value.(uint8)
-		case packet.LocalPrefAttr:
-			path.BGPPath.LocalPref = pa.Value.(uint32)
-		case packet.MEDAttr:
-			path.BGPPath.MED = pa.Value.(uint32)
-		case packet.NextHopAttr:
-			path.BGPPath.NextHop = pa.Value.(bnet.IP)
-		case packet.ASPathAttr:
-			path.BGPPath.ASPath = pa.Value.(types.ASPath)
-			path.BGPPath.ASPathLen = path.BGPPath.ASPath.Length()
-		case packet.AggregatorAttr:
-			aggr := pa.Value.(types.Aggregator)
-			path.BGPPath.Aggregator = &aggr
-		case packet.AtomicAggrAttr:
-			path.BGPPath.AtomicAggregate = true
-		case packet.CommunitiesAttr:
-			path.BGPPath.Communities = pa.Value.([]uint32)
-		case packet.LargeCommunitiesAttr:
-			path.BGPPath.LargeCommunities = pa.Value.([]types.LargeCommunity)
-		case packet.OriginatorIDAttr:
-			path.BGPPath.OriginatorID = pa.Value.(uint32)
-		case packet.ClusterListAttr:
-			path.BGPPath.ClusterList = pa.Value.([]uint32)
-		default:
-			unknownAttr := s.processUnknownAttribute(pa)
-			if unknownAttr != nil {
-				path.BGPPath.UnknownAttributes = append(path.BGPPath.UnknownAttributes, *unknownAttr)
-			}
+		if cur.TypeCode == packet.MultiProtocolReachNLRICode {
+			a := cur.Value.(packet.MultiProtocolReachNLRI)
+			return a.AFI, a.SAFI
 		}
-	}
-}
-
-func (s *establishedState) processUnknownAttribute(attr *packet.PathAttribute) *types.UnknownPathAttribute {
-	if !attr.Transitive {
-		return nil
-	}
 
-	u := &types.UnknownPathAttribute{
-		Transitive: true,
-		Optional:   attr.Optional,
-		Partial:    attr.Partial,
-		TypeCode:   attr.TypeCode,
-		Value:      attr.Value.([]byte),
+		if cur.TypeCode == packet.MultiProtocolUnreachNLRICode {
+			a := cur.Value.(packet.MultiProtocolUnreachNLRI)
+			return a.AFI, a.SAFI
+		}
 	}
 
-	return u
+	return
 }
 
 func (s *establishedState) keepaliveReceived() (state, string) {
diff --git a/protocols/bgp/server/fsm_test.go b/protocols/bgp/server/fsm_test.go
index 5cca8d85..e39d3d6a 100644
--- a/protocols/bgp/server/fsm_test.go
+++ b/protocols/bgp/server/fsm_test.go
@@ -94,7 +94,7 @@ func TestFSM100Updates(t *testing.T) {
 	}
 
 	time.Sleep(time.Second)
-	ribRouteCount := fsmA.rib.RouteCount()
+	ribRouteCount := fsmA.ipv4Unicast.rib.RouteCount()
 	if ribRouteCount != 255 {
 		t.Errorf("Unexpected route count in LocRIB: %d", ribRouteCount)
 	}
@@ -112,11 +112,11 @@ func TestFSM100Updates(t *testing.T) {
 			0, 0,
 		}
 		fsmA.msgRecvCh <- update
-		ribRouteCount = fsmA.rib.RouteCount()
+		ribRouteCount = fsmA.ipv4Unicast.rib.RouteCount()
 	}
 	time.Sleep(time.Second * 1)
 
-	ribRouteCount = fsmA.rib.RouteCount()
+	ribRouteCount = fsmA.ipv4Unicast.rib.RouteCount()
 	if ribRouteCount != 0 {
 		t.Errorf("Unexpected route count in LocRIB: %d", ribRouteCount)
 	}
diff --git a/protocols/bgp/server/helper.go b/protocols/bgp/server/helper.go
new file mode 100644
index 00000000..7a7255f0
--- /dev/null
+++ b/protocols/bgp/server/helper.go
@@ -0,0 +1,14 @@
+package server
+
+import (
+	bnet "github.com/bio-routing/bio-rd/net"
+	"github.com/bio-routing/bio-rd/protocols/bgp/packet"
+)
+
+func afiForPrefix(pfx bnet.Prefix) uint16 {
+	if pfx.Addr().IsIPv4() {
+		return packet.IPv6AFI
+	}
+
+	return packet.IPv6AFI
+}
diff --git a/protocols/bgp/server/update_sender.go b/protocols/bgp/server/update_sender.go
index ff82ba88..bc0f204a 100644
--- a/protocols/bgp/server/update_sender.go
+++ b/protocols/bgp/server/update_sender.go
@@ -16,6 +16,8 @@ import (
 type UpdateSender struct {
 	routingtable.ClientManager
 	fsm       *FSM
+	afi       uint16
+	safi      uint8
 	iBGP      bool
 	rrClient  bool
 	toSendMu  sync.Mutex
@@ -28,9 +30,11 @@ type pathPfxs struct {
 	pfxs []bnet.Prefix
 }
 
-func newUpdateSender(fsm *FSM) *UpdateSender {
+func newUpdateSender(fsm *FSM, afi uint16, safi uint8) *UpdateSender {
 	return &UpdateSender{
 		fsm:       fsm,
+		afi:       afi,
+		safi:      safi,
 		iBGP:      fsm.peer.localASN == fsm.peer.peerASN,
 		rrClient:  fsm.peer.routeReflectorClient,
 		destroyCh: make(chan struct{}),
@@ -90,7 +94,7 @@ func (u *UpdateSender) sender(aggrTime time.Duration) {
 
 		for key, pathNLRIs := range u.toSend {
 			budget = packet.MaxLen - packet.HeaderLen - packet.MinUpdateLen - int(pathNLRIs.path.BGPPath.Length()) - overhead
-      
+
 			pathAttrs, err = packet.PathAttributes(pathNLRIs.path, u.iBGP, u.rrClient)
 			if err != nil {
 				log.Errorf("Unable to get path attributes: %v", err)
@@ -125,19 +129,27 @@ func (u *UpdateSender) sender(aggrTime time.Duration) {
 }
 
 func (u *UpdateSender) updateOverhead() int {
-	// TODO: for multi RIB support we need the AFI/SAFI combination to determine overhead. For now: MultiProtocol = IPv6
-	if u.fsm.options.SupportsMultiProtocol {
-		// since we are replacing the next hop attribute IPv4Len has to be subtracted, we also add another byte for extended length
-		return packet.AFILen + packet.SAFILen + 1 + packet.IPv6Len - packet.IPv4Len + 1
+	if !u.fsm.options.SupportsMultiProtocol {
+		return 0
 	}
 
-	return 0
+	addrLen := packet.IPv4AFI
+	if u.afi == packet.IPv6AFI {
+		addrLen = packet.IPv6Len
+	}
+
+	// since we are replacing the next hop attribute IPv4Len has to be subtracted, we also add another byte for extended length
+	return packet.AFILen + packet.SAFILen + 1 + addrLen - packet.IPv4Len + 1
 }
 
 func (u *UpdateSender) sendUpdates(pathAttrs *packet.PathAttribute, updatePrefixes [][]bnet.Prefix, pathID uint32) {
 	var err error
 	for _, prefixes := range updatePrefixes {
 		update := u.updateMessageForPrefixes(prefixes, pathAttrs, pathID)
+		if update == nil {
+			log.Errorf("Failed to create update: Neighbor does not support multi protocol.")
+			return
+		}
 
 		err = serializeAndSendUpdate(u.fsm.con, update, u.fsm.options)
 		if err != nil {
@@ -147,11 +159,15 @@ func (u *UpdateSender) sendUpdates(pathAttrs *packet.PathAttribute, updatePrefix
 }
 
 func (u *UpdateSender) updateMessageForPrefixes(pfxs []bnet.Prefix, pa *packet.PathAttribute, pathID uint32) *packet.BGPUpdate {
+	if u.afi == packet.IPv4AFI && u.safi == packet.UnicastSAFI {
+		return u.bgpUpdate(pfxs, pa, pathID)
+	}
+
 	if u.fsm.options.SupportsMultiProtocol {
 		return u.bgpUpdateMultiProtocol(pfxs, pa, pathID)
 	}
 
-	return u.bgpUpdate(pfxs, pa, pathID)
+	return nil
 }
 
 func (u *UpdateSender) bgpUpdate(pfxs []bnet.Prefix, pa *packet.PathAttribute, pathID uint32) *packet.BGPUpdate {
@@ -179,8 +195,8 @@ func (u *UpdateSender) bgpUpdateMultiProtocol(pfxs []bnet.Prefix, pa *packet.Pat
 	attrs := &packet.PathAttribute{
 		TypeCode: packet.MultiProtocolReachNLRICode,
 		Value: packet.MultiProtocolReachNLRI{
-			AFI:      packet.IPv6AFI,
-			SAFI:     packet.UnicastSAFI,
+			AFI:      u.afi,
+			SAFI:     u.safi,
 			NextHop:  nextHop,
 			Prefixes: pfxs,
 		},
@@ -228,7 +244,7 @@ func (u *UpdateSender) RemovePath(pfx bnet.Prefix, p *route.Path) bool {
 
 func (u *UpdateSender) withdrawPrefix(pfx bnet.Prefix, p *route.Path) error {
 	if u.fsm.options.SupportsMultiProtocol {
-		return withDrawPrefixesMultiProtocol(u.fsm.con, u.fsm.options, pfx)
+		return withDrawPrefixesMultiProtocol(u.fsm.con, u.fsm.options, pfx, u.afi, u.safi)
 	}
 
 	return withDrawPrefixesAddPath(u.fsm.con, u.fsm.options, pfx, p)
diff --git a/protocols/bgp/server/update_sender_test.go b/protocols/bgp/server/update_sender_test.go
index 5c934fd2..646bf5d0 100644
--- a/protocols/bgp/server/update_sender_test.go
+++ b/protocols/bgp/server/update_sender_test.go
@@ -5,6 +5,8 @@ import (
 	"testing"
 	"time"
 
+	"github.com/bio-routing/bio-rd/protocols/bgp/packet"
+
 	"github.com/stretchr/testify/assert"
 
 	bnet "github.com/bio-routing/bio-rd/net"
@@ -22,10 +24,11 @@ func TestSender(t *testing.T) {
 		generateNLRIs   uint64
 		expectedUpdates [][]byte
 		addPath         bool
-		ipv6            bool
+		afi             uint16
 	}{
 		{
 			name: "Two paths with 3 NLRIs each",
+			afi:  packet.IPv4AFI,
 			paths: []pathPfxs{
 				{
 					path: &route.Path{
@@ -87,6 +90,7 @@ func TestSender(t *testing.T) {
 		},
 		{
 			name:    "Two paths with 3 NLRIs each with BGP Add Path",
+			afi:     packet.IPv4AFI,
 			addPath: true,
 			paths: []pathPfxs{
 				{
@@ -165,6 +169,7 @@ func TestSender(t *testing.T) {
 		},
 		{
 			name: "Overflow. Too many NLRIs.",
+			afi:  packet.IPv4AFI,
 			paths: []pathPfxs{
 				{
 					path: &route.Path{
@@ -332,6 +337,7 @@ func TestSender(t *testing.T) {
 		},
 		{
 			name: "Overflow with IPv6. Too many NLRIs.",
+			afi:  packet.IPv6AFI,
 			paths: []pathPfxs{
 				{
 					path: &route.Path{
@@ -857,7 +863,6 @@ func TestSender(t *testing.T) {
 					0x0, 0x40, 0x1, 0x1, 0x0, 0x40, 0x5, 0x4, 0x0, 0x0, 0x0, 0x64,
 				},
 			},
-			ipv6: true,
 		},
 	}
 
@@ -881,11 +886,11 @@ func TestSender(t *testing.T) {
 			}
 		}
 
-		if test.ipv6 {
+		if test.afi == packet.IPv6AFI {
 			fsmA.options.SupportsMultiProtocol = true
 		}
 
-		updateSender := newUpdateSender(fsmA)
+		updateSender := newUpdateSender(fsmA, test.afi, packet.UnicastSAFI)
 
 		for _, pathPfx := range test.paths {
 			for _, pfx := range pathPfx.pfxs {
@@ -897,7 +902,7 @@ func TestSender(t *testing.T) {
 					y := i - x
 
 					var pfx bnet.Prefix
-					if test.ipv6 {
+					if test.afi == packet.IPv6AFI {
 						pfx = bnet.NewPfx(bnet.IPv6FromBlocks(0x2001, 0x678, 0x1e0, 0, 0, 0, 0, 0), 48)
 					} else {
 						pfx = bnet.NewPfx(bnet.IPv4FromOctets(10, 0, uint8(x), uint8(y)), 32)
diff --git a/protocols/bgp/server/withdraw.go b/protocols/bgp/server/withdraw.go
index 013f6ce3..2ab554b6 100644
--- a/protocols/bgp/server/withdraw.go
+++ b/protocols/bgp/server/withdraw.go
@@ -59,16 +59,17 @@ func withDrawPrefixesAddPath(out io.Writer, opt *types.Options, pfx net.Prefix,
 	return serializeAndSendUpdate(out, update, opt)
 }
 
-func withDrawPrefixesMultiProtocol(out io.Writer, opt *types.Options, pfx net.Prefix) error {
+func withDrawPrefixesMultiProtocol(out io.Writer, opt *types.Options, pfx net.Prefix, afi uint16, safi uint8) error {
 	update := &packet.BGPUpdate{
 		PathAttributes: &packet.PathAttribute{
 			TypeCode: packet.MultiProtocolUnreachNLRICode,
 			Value: packet.MultiProtocolUnreachNLRI{
-				AFI:      packet.IPv6AFI,
-				SAFI:     packet.UnicastSAFI,
+				AFI:      afi,
+				SAFI:     safi,
 				Prefixes: []net.Prefix{pfx},
 			},
 		},
 	}
+
 	return serializeAndSendUpdate(out, update, opt)
 }
diff --git a/protocols/bgp/server/withdraw_test.go b/protocols/bgp/server/withdraw_test.go
index 393be032..693cbb4e 100644
--- a/protocols/bgp/server/withdraw_test.go
+++ b/protocols/bgp/server/withdraw_test.go
@@ -3,6 +3,8 @@ package server
 import (
 	"testing"
 
+	"github.com/bio-routing/bio-rd/protocols/bgp/packet"
+
 	"github.com/bio-routing/bio-rd/protocols/bgp/types"
 
 	"errors"
@@ -91,7 +93,7 @@ func TestWithDrawPrefixesMultiProtocol(t *testing.T) {
 			opt := &types.Options{
 				AddPathRX: false,
 			}
-			err := withDrawPrefixesMultiProtocol(buf, opt, test.Prefix)
+			err := withDrawPrefixesMultiProtocol(buf, opt, test.Prefix, packet.IPv6AFI, packet.UnicastSAFI)
 			if err != nil {
 				t.Fatalf("unexpected error: %v", err)
 			}
-- 
GitLab