diff --git a/protocols/bgp/server/fsm_address_family.go b/protocols/bgp/server/fsm_address_family.go index bd3e05d599d921b5f13cd0aec5f14c6df0373735..04234c3d604ad8a214b9de73f15c4254c26de290 100644 --- a/protocols/bgp/server/fsm_address_family.go +++ b/protocols/bgp/server/fsm_address_family.go @@ -66,12 +66,13 @@ func (f *fsmAddressFamily) init(n *routingtable.Neighbor) { f.adjRIBOut = adjRIBOut.New(n, f.exportFilter, !f.addPathTX.BestOnly) - f.updateSender = newUpdateSender(f.fsm, f.afi, f.safi) + f.updateSender = newUpdateSender(f) f.updateSender.Start(time.Millisecond * 5) f.adjRIBOut.Register(f.updateSender) f.rib.RegisterWithOptions(f.adjRIBOut, f.addPathTX) + f.initialized = true } func (f *fsmAddressFamily) bmpInit() { diff --git a/protocols/bgp/server/fsm_address_family_test.go b/protocols/bgp/server/fsm_address_family_test.go index cea9ca10cabf02d630a8f1ef00b549382c1dd5e7..9c885b2bd970c1ab87282a48881cf881e570d309 100644 --- a/protocols/bgp/server/fsm_address_family_test.go +++ b/protocols/bgp/server/fsm_address_family_test.go @@ -1,14 +1,71 @@ package server import ( + "sync" "testing" "github.com/bio-routing/bio-rd/protocols/bgp/packet" "github.com/bio-routing/bio-rd/protocols/bgp/types" "github.com/bio-routing/bio-rd/route" + "github.com/bio-routing/bio-rd/routingtable" + "github.com/bio-routing/bio-rd/routingtable/filter" + "github.com/bio-routing/bio-rd/routingtable/locRIB" "github.com/stretchr/testify/assert" ) +func TestFSMAFIInitDispose(t *testing.T) { + f := &fsmAddressFamily{ + afi: packet.IPv4AFI, + safi: packet.UnicastSAFI, + rib: locRIB.New("inet.0"), + importFilter: filter.NewAcceptAllFilter(), + exportFilter: filter.NewAcceptAllFilter(), + fsm: &FSM{ + peer: &peer{ + routerID: 100, + localASN: 15169, + }, + }, + addPathTX: routingtable.ClientOptions{ + BestOnly: true, + }, + } + + n := &routingtable.Neighbor{ + LocalASN: 15169, + } + + assert.Equal(t, uint64(0), f.rib.ClientCount()) + + f.init(n) + assert.NotEqual(t, nil, f.adjRIBIn) + assert.Equal(t, true, f.rib.GetContributingASNs().IsContributingASN(15169)) + assert.NotEqual(t, true, f.rib.GetContributingASNs().IsContributingASN(15170)) + + assert.NotEqual(t, nil, f.adjRIBOut) + assert.NotEqual(t, nil, f.updateSender) + + assert.Equal(t, uint64(1), f.adjRIBIn.ClientCount()) + assert.Equal(t, uint64(1), f.rib.ClientCount()) + + wg := sync.WaitGroup{} + wg.Add(1) + assert.Equal(t, wg, f.updateSender.wg) + + assert.Equal(t, uint64(1), f.adjRIBOut.ClientCount()) + + assert.Equal(t, true, f.initialized) + + // Dispose + f.dispose() + + f.updateSender.wg.Wait() + assert.Equal(t, false, f.rib.GetContributingASNs().IsContributingASN(15169)) + assert.Equal(t, uint64(0), f.rib.ClientCount()) + assert.Equal(t, nil, f.adjRIBOut) + assert.Equal(t, false, f.initialized) +} + func TestProcessAttributes(t *testing.T) { unknown3 := &packet.PathAttribute{ Transitive: true, diff --git a/protocols/bgp/server/update_sender.go b/protocols/bgp/server/update_sender.go index 692fc0b3483c67a01200eb5ffc9621a02a019241..c43d1778ea225b7586c9e2d442832b1f4785d6e0 100644 --- a/protocols/bgp/server/update_sender.go +++ b/protocols/bgp/server/update_sender.go @@ -25,6 +25,7 @@ type UpdateSender struct { toSendMu sync.Mutex toSend map[string]*pathPfxs destroyCh chan struct{} + wg sync.WaitGroup } type pathPfxs struct { @@ -32,18 +33,16 @@ type pathPfxs struct { pfxs []bnet.Prefix } -func newUpdateSender(fsm *FSM, afi uint16, safi uint8) *UpdateSender { - f := fsm.addressFamily(afi, safi) - +func newUpdateSender(f *fsmAddressFamily) *UpdateSender { u := &UpdateSender{ - fsm: fsm, + fsm: f.fsm, addressFamily: f, - iBGP: fsm.peer.localASN == fsm.peer.peerASN, - rrClient: fsm.peer.routeReflectorClient, + iBGP: f.fsm.peer.localASN == f.fsm.peer.peerASN, + rrClient: f.fsm.peer.routeReflectorClient, destroyCh: make(chan struct{}), toSend: make(map[string]*pathPfxs), options: &packet.EncodeOptions{ - Use32BitASN: fsm.supports4OctetASN, + Use32BitASN: f.fsm.supports4OctetASN, UseAddPath: !f.addPathTX.BestOnly, }, } @@ -52,8 +51,14 @@ func newUpdateSender(fsm *FSM, afi uint16, safi uint8) *UpdateSender { return u } +// ClientCount is here to satisfy an interface +func (u *UpdateSender) ClientCount() uint64 { + return 0 +} + // Start starts the update sender func (u *UpdateSender) Start(aggrTime time.Duration) { + u.wg.Add(1) go u.sender(aggrTime) } @@ -99,6 +104,7 @@ func (u *UpdateSender) sender(aggrTime time.Duration) { for { select { case <-u.destroyCh: + u.wg.Done() return case <-ticker.C: } diff --git a/protocols/bgp/server/update_sender_test.go b/protocols/bgp/server/update_sender_test.go index 1d4d97fd4f2735283327341ade693344ecd484b4..c842621fdc9cd5b732dac0f5b88ed78f0c958c5c 100644 --- a/protocols/bgp/server/update_sender_test.go +++ b/protocols/bgp/server/update_sender_test.go @@ -902,7 +902,7 @@ func TestSender(t *testing.T) { fsmA.state = newEstablishedState(fsmA) fsmA.con = btest.NewMockConn() - updateSender := newUpdateSender(fsmA, test.afi, packet.UnicastSAFI) + updateSender := newUpdateSender(fsmA.addressFamily(test.afi, packet.UnicastSAFI)) for _, pathPfx := range test.paths { for _, pfx := range pathPfx.pfxs { diff --git a/protocols/fib/reader_linux.go b/protocols/fib/reader_linux.go index cb5653bffdae45f5fb1f1819ef6b2a871a007c65..c91e8e7cef4c9a08526d2c039088cf88c46de38e 100644 --- a/protocols/fib/reader_linux.go +++ b/protocols/fib/reader_linux.go @@ -42,6 +42,11 @@ func NewNetlinkReader(options *config.Netlink) *NetlinkReader { return nr } +// ClientCount is here to satisfy an interface +func (nr *NetlinkReader) ClientCount() uint64 { + return 0 +} + // Dump is here to fulfill an interface func (nr *NetlinkReader) Dump() []*route.Route { return nil diff --git a/routingtable/adjRIBIn/adj_rib_in.go b/routingtable/adjRIBIn/adj_rib_in.go index fd3675655c3e1f7e2c9f4e42d1eafb7061574a94..cd6784e1e7688f22e02235ece2920caa0fe2eb00 100644 --- a/routingtable/adjRIBIn/adj_rib_in.go +++ b/routingtable/adjRIBIn/adj_rib_in.go @@ -36,6 +36,11 @@ func New(exportFilter *filter.Filter, contributingASNs *routingtable.Contributin return a } +// ClientCount gets the number of registered clients +func (a *AdjRIBIn) ClientCount() uint64 { + return a.clientManager.ClientCount() +} + // Dump dumps the RIB func (a *AdjRIBIn) Dump() []*route.Route { a.mu.Lock() diff --git a/routingtable/adjRIBOut/adj_rib_out.go b/routingtable/adjRIBOut/adj_rib_out.go index b48f9654db1700812c7eb9137f5772f0659be038..9e33165f0b43c3432f66c800fc4f9ac9ae39f164 100644 --- a/routingtable/adjRIBOut/adj_rib_out.go +++ b/routingtable/adjRIBOut/adj_rib_out.go @@ -36,6 +36,11 @@ func New(neighbor *routingtable.Neighbor, exportFilter *filter.Filter, addPathTX return a } +// ClientCount gets the number of registered clients +func (a *AdjRIBOut) ClientCount() uint64 { + return a.clientManager.ClientCount() +} + // Dump dumps the RIB func (a *AdjRIBOut) Dump() []*route.Route { a.mu.RLock() diff --git a/routingtable/client_interface.go b/routingtable/client_interface.go index 7a78c080ca2fca662d90fa3c8834bcead2374eee..ecc93035d82718093548f49defe44d91f0808b2a 100644 --- a/routingtable/client_interface.go +++ b/routingtable/client_interface.go @@ -13,5 +13,6 @@ type RouteTableClient interface { Register(RouteTableClient) Unregister(RouteTableClient) RouteCount() int64 + ClientCount() uint64 Dump() []*route.Route } diff --git a/routingtable/client_manager.go b/routingtable/client_manager.go index 824f1151491aaf9d85535c8e82374f38a69aa7f2..5e4ac377ddeb872a606038c72d016cce35a12b95 100644 --- a/routingtable/client_manager.go +++ b/routingtable/client_manager.go @@ -39,6 +39,14 @@ func NewClientManager(master RouteTableClient) *ClientManager { } } +// ClientCount gets the number of registred clients +func (c *ClientManager) ClientCount() uint64 { + c.mu.RLock() + defer c.mu.RUnlock() + + return uint64(len(c.clients)) +} + // GetOptions gets the options for a registered client func (c *ClientManager) GetOptions(client RouteTableClient) ClientOptions { c.mu.RLock() diff --git a/routingtable/client_manager_test.go b/routingtable/client_manager_test.go index 00a7ced2c148f1261d294753670f6ab882b7e575..fde4bc71add434a212eea9dbb4922630d00446cb 100644 --- a/routingtable/client_manager_test.go +++ b/routingtable/client_manager_test.go @@ -12,6 +12,10 @@ type MockClient struct { foo int } +func (m MockClient) ClientCount() uint64 { + return 0 +} + func (m MockClient) Dump() []*route.Route { return nil } diff --git a/routingtable/locRIB/loc_rib.go b/routingtable/locRIB/loc_rib.go index e4b6ec9dffa293b762fdb86be6daf46733acf9b1..b7a630ed516f6c63b8ce05c1544b778d3641be9a 100644 --- a/routingtable/locRIB/loc_rib.go +++ b/routingtable/locRIB/loc_rib.go @@ -38,6 +38,11 @@ func New(name string) *LocRIB { return a } +// ClientCount gets the number of registered clients +func (a *LocRIB) ClientCount() uint64 { + return a.clientManager.ClientCount() +} + // GetContributingASNs returns a pointer to the list of contributing ASNs func (a *LocRIB) GetContributingASNs() *routingtable.ContributingASNs { return a.contributingASNs diff --git a/routingtable/mock_client.go b/routingtable/mock_client.go index c009c4c3feb4c9a45f047adb8dd458bc1d24a4b2..25a26c400549c1458ec697720856a6898c457821 100644 --- a/routingtable/mock_client.go +++ b/routingtable/mock_client.go @@ -22,6 +22,10 @@ func NewRTMockClient() *RTMockClient { } } +func (m *RTMockClient) ClientCount() uint64 { + return 0 +} + func (m *RTMockClient) Removed() []*RemovePathParams { return m.removed }