diff --git a/routingtable/adjRIBIn/adj_rib_in.go b/routingtable/adjRIBIn/adj_rib_in.go index 446d71e02982a6bd903b9890c15e8501af6f19a1..011372794dd6493a4c7d8f0e73e7b482f7934825 100644 --- a/routingtable/adjRIBIn/adj_rib_in.go +++ b/routingtable/adjRIBIn/adj_rib_in.go @@ -22,25 +22,30 @@ func NewAdjRIBIn() *AdjRIBIn { // AddPath replaces the path for prefix `pfx`. If the prefix doesn't exist it is added. func (a *AdjRIBIn) AddPath(pfx net.Prefix, p *route.Path) error { oldPaths := a.rt.ReplacePath(pfx, p) - - for _, oldPath := range oldPaths { - for _, client := range a.ClientManager.Clients() { - client.RemovePath(pfx, oldPath) - } - } - + a.removePathsFromClients(pfx, oldPaths) return nil } // RemovePath removes the path for prefix `pfx` func (a *AdjRIBIn) RemovePath(pfx net.Prefix, p *route.Path) error { - if !a.rt.RemovePath(pfx, p) { + r := a.rt.Get(pfx) + if r == nil { return nil } - for _, client := range a.ClientManager.Clients() { - client.RemovePath(pfx, p) + oldPaths := r.Paths() + for _, path := range oldPaths { + a.rt.RemovePath(pfx, path) } + a.removePathsFromClients(pfx, oldPaths) return nil } + +func (a *AdjRIBIn) removePathsFromClients(pfx net.Prefix, paths []*route.Path) { + for _, path := range paths { + for _, client := range a.ClientManager.Clients() { + client.RemovePath(pfx, path) + } + } +} diff --git a/routingtable/adjRIBIn/adj_rib_in_test.go b/routingtable/adjRIBIn/adj_rib_in_test.go index a6c8989ee760a17ac1c2556533ade447cbcfdb45..4ae96263ba2171f17ac7681fbcccde3348de8205 100644 --- a/routingtable/adjRIBIn/adj_rib_in_test.go +++ b/routingtable/adjRIBIn/adj_rib_in_test.go @@ -200,12 +200,9 @@ func TestRemovePath(t *testing.T) { if mc.removePathParams.pfx != test.removePfx { t.Errorf("Test %q failed: Call to RemovePath did not propagate prefix properly: Got: %s Want: %s", test.name, mc.removePathParams.pfx.String(), test.removePfx.String()) } - - if mc.removePathParams.path != test.removePath { - t.Errorf("Test %q failed: Call to RemovePath did not propagate path properly: Got: %v Want: %v", test.name, mc.removePathParams.path, test.removePath) - } + assert.Equal(t, test.removePath, mc.removePathParams.path) } else { - if mc.removePathParams.pfx != net.NewPfx(0, 0) || mc.removePathParams.path != nil { + if mc.removePathParams.pfx != net.NewPfx(0, 0) { t.Errorf("Test %q failed: Call to RemovePath propagated unexpectedly", test.name) } }