diff --git a/protocols/isis/server/device_manager_test.go b/protocols/isis/server/device_manager_test.go index bf5c580299a8a8ea0cbbfeef05231acb02e9e2ae..f709d24e4ec77667384ebb1253f1ac15899e74e7 100644 --- a/protocols/isis/server/device_manager_test.go +++ b/protocols/isis/server/device_manager_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/bio-routing/bio-rd/config" - "github.com/bio-routing/bio-rd/net" "github.com/bio-routing/bio-rd/protocols/device" "github.com/stretchr/testify/assert" ) @@ -246,57 +245,3 @@ func TestDeviceAddDevice(t *testing.T) { assert.Equal(t, test.expected, test.dm, test.name) } } - -func TestValidateNeighborAddresses(t *testing.T) { - tests := []struct { - name string - d *dev - addrs []uint32 - expected []uint32 - }{ - { - name: "Test #1", - d: &dev{ - phy: &device.Device{ - Addrs: []net.Prefix{ - net.NewPfx(net.IPv4FromOctets(10, 0, 0, 0), 24), - }, - }, - }, - addrs: []uint32{ - net.IPv4FromOctets(10, 0, 0, 2).ToUint32(), - }, - expected: []uint32{ - net.IPv4FromOctets(10, 0, 0, 2).ToUint32(), - }, - }, - { - name: "Test #2", - d: &dev{ - phy: &device.Device{ - Addrs: []net.Prefix{ - net.NewPfx(net.IPv4FromOctets(10, 0, 0, 0), 30), - net.NewPfx(net.IPv4FromOctets(10, 0, 0, 4), 30), - net.NewPfx(net.IPv4FromOctets(192, 168, 100, 0), 22), - }, - }, - }, - addrs: []uint32{ - net.IPv4FromOctets(100, 100, 100, 100).ToUint32(), - net.IPv4FromOctets(10, 0, 0, 5).ToUint32(), - net.IPv4FromOctets(10, 0, 0, 9).ToUint32(), - net.IPv4FromOctets(192, 168, 101, 22).ToUint32(), - net.IPv4FromOctets(10, 0, 0, 22).ToUint32(), - }, - expected: []uint32{ - net.IPv4FromOctets(10, 0, 0, 5).ToUint32(), - net.IPv4FromOctets(192, 168, 101, 22).ToUint32(), - }, - }, - } - - for _, test := range tests { - res := test.d.validateNeighborAddresses(test.addrs) - assert.Equal(t, test.expected, res, test.name) - } -} diff --git a/protocols/isis/server/device_test.go b/protocols/isis/server/device_test.go index 34de65828530d207e8a1f75ffe896f619564e8f5..fa74045577735743afd5a92d6db1cbe8a383c19a 100644 --- a/protocols/isis/server/device_test.go +++ b/protocols/isis/server/device_test.go @@ -3,6 +3,7 @@ package server import ( "testing" + "github.com/bio-routing/bio-rd/net" "github.com/bio-routing/bio-rd/protocols/device" "github.com/stretchr/testify/assert" ) @@ -124,3 +125,57 @@ func TestDeviceUpdate(t *testing.T) { assert.Equal(t, test.expected, test.dev.up, test.name) } } + +func TestValidateNeighborAddresses(t *testing.T) { + tests := []struct { + name string + d *dev + addrs []uint32 + expected []uint32 + }{ + { + name: "Test #1", + d: &dev{ + phy: &device.Device{ + Addrs: []net.Prefix{ + net.NewPfx(net.IPv4FromOctets(10, 0, 0, 0), 24), + }, + }, + }, + addrs: []uint32{ + net.IPv4FromOctets(10, 0, 0, 2).ToUint32(), + }, + expected: []uint32{ + net.IPv4FromOctets(10, 0, 0, 2).ToUint32(), + }, + }, + { + name: "Test #2", + d: &dev{ + phy: &device.Device{ + Addrs: []net.Prefix{ + net.NewPfx(net.IPv4FromOctets(10, 0, 0, 0), 30), + net.NewPfx(net.IPv4FromOctets(10, 0, 0, 4), 30), + net.NewPfx(net.IPv4FromOctets(192, 168, 100, 0), 22), + }, + }, + }, + addrs: []uint32{ + net.IPv4FromOctets(100, 100, 100, 100).ToUint32(), + net.IPv4FromOctets(10, 0, 0, 5).ToUint32(), + net.IPv4FromOctets(10, 0, 0, 9).ToUint32(), + net.IPv4FromOctets(192, 168, 101, 22).ToUint32(), + net.IPv4FromOctets(10, 0, 0, 22).ToUint32(), + }, + expected: []uint32{ + net.IPv4FromOctets(10, 0, 0, 5).ToUint32(), + net.IPv4FromOctets(192, 168, 101, 22).ToUint32(), + }, + }, + } + + for _, test := range tests { + res := test.d.validateNeighborAddresses(test.addrs) + assert.Equal(t, test.expected, res, test.name) + } +} diff --git a/protocols/isis/server/neighbor.go b/protocols/isis/server/neighbor.go index 9fab886bad8d3c965e9ebae3b17abf7b1d268808..a055849bb00a1aad00f51bc6d8a2ed5af034a514 100644 --- a/protocols/isis/server/neighbor.go +++ b/protocols/isis/server/neighbor.go @@ -53,11 +53,11 @@ func (n *neighbor) hello(h *neighbor) (dispose bool) { return true } + n.holdingTime = h.holdingTime if !n.holdingTimer.Reset(time.Duration(n.holdingTime)) { n.dispose(fmt.Errorf("Hold timer expired")) return true } - n.holdingTime = n.holdingTime if !n.ipAddrsEqual(validAddrs) || n.localCircuitID != h.localCircuitID || n.extendedLocalCircuitID != h.extendedLocalCircuitID { n.dev.srv.lsdb.triggerLSPDUGen()