From 604d7ab5fc517e2c9ddfe7a895d9d8e2bf5aeb7e Mon Sep 17 00:00:00 2001
From: Malte Bauch <malte.bauch@tbnet.works>
Date: Mon, 5 Sep 2022 14:41:48 +0000
Subject: [PATCH] Resolve "Restarting the controller after devices are
 registered is throwing a panic"

See merge request danet/gosdn!373

Co-authored-by: Malte Bauch <malte.bauch@extern.h-da.de>
---
 controller/interfaces/device/device.go        |  14 ---
 controller/interfaces/device/deviceService.go |   3 +-
 controller/interfaces/networkdomain/pnd.go    |   3 +-
 controller/mocks/NetworkDomain.go             |  18 ++-
 controller/mocks/Service.go                   | 104 ++----------------
 controller/mocks/Store.go                     |  77 +++++++------
 controller/northbound/server/device.go        |   4 -
 controller/northbound/server/pnd.go           |  16 +--
 controller/nucleus/deviceService.go           |  26 ++++-
 controller/nucleus/deviceServiceMock.go       |  13 ++-
 controller/nucleus/deviceWatcher.go           |   3 +-
 controller/nucleus/principalNetworkDomain.go  |  12 +-
 .../nucleus/principalNetworkDomain_test.go    |   2 +-
 13 files changed, 125 insertions(+), 170 deletions(-)

diff --git a/controller/interfaces/device/device.go b/controller/interfaces/device/device.go
index 4becb471c..c09b30be6 100644
--- a/controller/interfaces/device/device.go
+++ b/controller/interfaces/device/device.go
@@ -51,18 +51,4 @@ type LoadedDevice struct {
 	// SBI indicates the southbound interface, which is used by this device as UUID.
 	SBI   string `json:"sbi"`
 	Model string `json:"model,omitempty" bson:"model,omitempty"`
-
-	convertFunc func(LoadedDevice) (Device, error)
-}
-
-// SetConvertFunction allows to set the LoadedDevice's convert function. This
-// function should take a LoadedDevice and returns a Device.
-func (ld *LoadedDevice) SetConvertFunction(cf func(LoadedDevice) (Device, error)) {
-	ld.convertFunc = cf
-}
-
-// ConvertToDevice calls the LoadedDevice's convert function and converts the
-// LoadedDevice into a Device.
-func (ld LoadedDevice) ConvertToDevice() (Device, error) {
-	return ld.convertFunc(ld)
 }
diff --git a/controller/interfaces/device/deviceService.go b/controller/interfaces/device/deviceService.go
index a9eb3990e..525678bd9 100644
--- a/controller/interfaces/device/deviceService.go
+++ b/controller/interfaces/device/deviceService.go
@@ -11,5 +11,6 @@ type Service interface {
 	UpdateModel(Device, string) error
 	Delete(Device) error
 	Get(store.Query) (Device, error)
-	GetAll() ([]LoadedDevice, error)
+	GetAll() ([]Device, error)
+	GetAllAsLoaded() ([]LoadedDevice, error)
 }
diff --git a/controller/interfaces/networkdomain/pnd.go b/controller/interfaces/networkdomain/pnd.go
index 082b8850d..12aff19d3 100644
--- a/controller/interfaces/networkdomain/pnd.go
+++ b/controller/interfaces/networkdomain/pnd.go
@@ -20,7 +20,8 @@ type NetworkDomain interface {
 	GetDevice(identifier string) (device.Device, error)
 	RemoveDevice(uuid.UUID) error
 	UpdateDevice(device.Device, string) error
-	Devices() []device.LoadedDevice
+	Devices() []device.Device
+	FlattenedDevices() []device.LoadedDevice
 	ChangeOND(uuid uuid.UUID, operation ppb.ApiOperation, path string, value ...string) (uuid.UUID, error)
 	Request(uuid.UUID, string) (proto.Message, error)
 	RequestAll(string) error
diff --git a/controller/mocks/NetworkDomain.go b/controller/mocks/NetworkDomain.go
index 22b51b137..6a770ba4b 100644
--- a/controller/mocks/NetworkDomain.go
+++ b/controller/mocks/NetworkDomain.go
@@ -150,7 +150,23 @@ func (_m *NetworkDomain) Destroy() error {
 }
 
 // Devices provides a mock function with given fields:
-func (_m *NetworkDomain) Devices() []device.LoadedDevice {
+func (_m *NetworkDomain) Devices() []device.Device {
+	ret := _m.Called()
+
+	var r0 []device.Device
+	if rf, ok := ret.Get(0).(func() []device.Device); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).([]device.Device)
+		}
+	}
+
+	return r0
+}
+
+// FlattenedDevices provides a mock function with given fields:
+func (_m *NetworkDomain) FlattenedDevices() []device.LoadedDevice {
 	ret := _m.Called()
 
 	var r0 []device.LoadedDevice
diff --git a/controller/mocks/Service.go b/controller/mocks/Service.go
index 188cd0c9c..55386e172 100644
--- a/controller/mocks/Service.go
+++ b/controller/mocks/Service.go
@@ -3,10 +3,9 @@
 package mocks
 
 import (
-	device "code.fbi.h-da.de/danet/gosdn/controller/interfaces/device"
-	mock "github.com/stretchr/testify/mock"
+	controllerevent "code.fbi.h-da.de/danet/gosdn/controller/event"
 
-	store "code.fbi.h-da.de/danet/gosdn/controller/store"
+	mock "github.com/stretchr/testify/mock"
 )
 
 // Service is an autogenerated mock type for the Service type
@@ -14,101 +13,18 @@ type Service struct {
 	mock.Mock
 }
 
-// Add provides a mock function with given fields: _a0
-func (_m *Service) Add(_a0 device.Device) error {
-	ret := _m.Called(_a0)
-
-	var r0 error
-	if rf, ok := ret.Get(0).(func(device.Device) error); ok {
-		r0 = rf(_a0)
-	} else {
-		r0 = ret.Error(0)
-	}
-
-	return r0
-}
-
-// Delete provides a mock function with given fields: _a0
-func (_m *Service) Delete(_a0 device.Device) error {
-	ret := _m.Called(_a0)
-
-	var r0 error
-	if rf, ok := ret.Get(0).(func(device.Device) error); ok {
-		r0 = rf(_a0)
-	} else {
-		r0 = ret.Error(0)
-	}
-
-	return r0
-}
-
-// Get provides a mock function with given fields: _a0
-func (_m *Service) Get(_a0 store.Query) (device.Device, error) {
-	ret := _m.Called(_a0)
-
-	var r0 device.Device
-	if rf, ok := ret.Get(0).(func(store.Query) device.Device); ok {
-		r0 = rf(_a0)
-	} else {
-		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(device.Device)
-		}
-	}
-
-	var r1 error
-	if rf, ok := ret.Get(1).(func(store.Query) error); ok {
-		r1 = rf(_a0)
-	} else {
-		r1 = ret.Error(1)
-	}
-
-	return r0, r1
-}
-
-// GetAll provides a mock function with given fields:
-func (_m *Service) GetAll() ([]device.Device, error) {
-	ret := _m.Called()
-
-	var r0 []device.Device
-	if rf, ok := ret.Get(0).(func() []device.Device); ok {
-		r0 = rf()
-	} else {
-		if ret.Get(0) != nil {
-			r0 = ret.Get(0).([]device.Device)
-		}
-	}
-
-	var r1 error
-	if rf, ok := ret.Get(1).(func() error); ok {
-		r1 = rf()
-	} else {
-		r1 = ret.Error(1)
-	}
-
-	return r0, r1
-}
-
-// Update provides a mock function with given fields: _a0
-func (_m *Service) Update(_a0 device.Device) error {
-	ret := _m.Called(_a0)
-
-	var r0 error
-	if rf, ok := ret.Get(0).(func(device.Device) error); ok {
-		r0 = rf(_a0)
-	} else {
-		r0 = ret.Error(0)
-	}
-
-	return r0
+// CloseConnection provides a mock function with given fields:
+func (_m *Service) CloseConnection() {
+	_m.Called()
 }
 
-// UpdateModel provides a mock function with given fields: _a0, _a1
-func (_m *Service) UpdateModel(_a0 device.Device, _a1 string) error {
-	ret := _m.Called(_a0, _a1)
+// PublishEvent provides a mock function with given fields: topic, _a1
+func (_m *Service) PublishEvent(topic string, _a1 controllerevent.Event) error {
+	ret := _m.Called(topic, _a1)
 
 	var r0 error
-	if rf, ok := ret.Get(0).(func(device.Device, string) error); ok {
-		r0 = rf(_a0, _a1)
+	if rf, ok := ret.Get(0).(func(string, controllerevent.Event) error); ok {
+		r0 = rf(topic, _a1)
 	} else {
 		r0 = ret.Error(0)
 	}
diff --git a/controller/mocks/Store.go b/controller/mocks/Store.go
index f8f71cd7d..6ce0fd985 100644
--- a/controller/mocks/Store.go
+++ b/controller/mocks/Store.go
@@ -3,8 +3,8 @@
 package mocks
 
 import (
-	southbound "code.fbi.h-da.de/danet/gosdn/controller/interfaces/southbound"
-	store "code.fbi.h-da.de/danet/gosdn/controller/store"
+	store "code.fbi.h-da.de/danet/gosdn/controller/interfaces/store"
+	uuid "github.com/google/uuid"
 	mock "github.com/stretchr/testify/mock"
 )
 
@@ -13,13 +13,13 @@ type Store struct {
 	mock.Mock
 }
 
-// Add provides a mock function with given fields: _a0
-func (_m *Store) Add(_a0 southbound.SouthboundInterface) error {
-	ret := _m.Called(_a0)
+// Add provides a mock function with given fields: item
+func (_m *Store) Add(item store.Storable) error {
+	ret := _m.Called(item)
 
 	var r0 error
-	if rf, ok := ret.Get(0).(func(southbound.SouthboundInterface) error); ok {
-		r0 = rf(_a0)
+	if rf, ok := ret.Get(0).(func(store.Storable) error); ok {
+		r0 = rf(item)
 	} else {
 		r0 = ret.Error(0)
 	}
@@ -27,13 +27,13 @@ func (_m *Store) Add(_a0 southbound.SouthboundInterface) error {
 	return r0
 }
 
-// Delete provides a mock function with given fields: _a0
-func (_m *Store) Delete(_a0 southbound.SouthboundInterface) error {
-	ret := _m.Called(_a0)
+// Delete provides a mock function with given fields: id
+func (_m *Store) Delete(id uuid.UUID) error {
+	ret := _m.Called(id)
 
 	var r0 error
-	if rf, ok := ret.Get(0).(func(southbound.SouthboundInterface) error); ok {
-		r0 = rf(_a0)
+	if rf, ok := ret.Get(0).(func(uuid.UUID) error); ok {
+		r0 = rf(id)
 	} else {
 		r0 = ret.Error(0)
 	}
@@ -41,20 +41,36 @@ func (_m *Store) Delete(_a0 southbound.SouthboundInterface) error {
 	return r0
 }
 
-// Get provides a mock function with given fields: _a0
-func (_m *Store) Get(_a0 store.Query) (southbound.LoadedSbi, error) {
-	ret := _m.Called(_a0)
+// Exists provides a mock function with given fields: id
+func (_m *Store) Exists(id uuid.UUID) bool {
+	ret := _m.Called(id)
 
-	var r0 southbound.LoadedSbi
-	if rf, ok := ret.Get(0).(func(store.Query) southbound.LoadedSbi); ok {
-		r0 = rf(_a0)
+	var r0 bool
+	if rf, ok := ret.Get(0).(func(uuid.UUID) bool); ok {
+		r0 = rf(id)
 	} else {
-		r0 = ret.Get(0).(southbound.LoadedSbi)
+		r0 = ret.Get(0).(bool)
+	}
+
+	return r0
+}
+
+// Get provides a mock function with given fields: id
+func (_m *Store) Get(id uuid.UUID) (store.Storable, error) {
+	ret := _m.Called(id)
+
+	var r0 store.Storable
+	if rf, ok := ret.Get(0).(func(uuid.UUID) store.Storable); ok {
+		r0 = rf(id)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(store.Storable)
+		}
 	}
 
 	var r1 error
-	if rf, ok := ret.Get(1).(func(store.Query) error); ok {
-		r1 = rf(_a0)
+	if rf, ok := ret.Get(1).(func(uuid.UUID) error); ok {
+		r1 = rf(id)
 	} else {
 		r1 = ret.Error(1)
 	}
@@ -62,27 +78,20 @@ func (_m *Store) Get(_a0 store.Query) (southbound.LoadedSbi, error) {
 	return r0, r1
 }
 
-// GetAll provides a mock function with given fields:
-func (_m *Store) GetAll() ([]southbound.LoadedSbi, error) {
+// UUIDs provides a mock function with given fields:
+func (_m *Store) UUIDs() []uuid.UUID {
 	ret := _m.Called()
 
-	var r0 []southbound.LoadedSbi
-	if rf, ok := ret.Get(0).(func() []southbound.LoadedSbi); ok {
+	var r0 []uuid.UUID
+	if rf, ok := ret.Get(0).(func() []uuid.UUID); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).([]southbound.LoadedSbi)
+			r0 = ret.Get(0).([]uuid.UUID)
 		}
 	}
 
-	var r1 error
-	if rf, ok := ret.Get(1).(func() error); ok {
-		r1 = rf()
-	} else {
-		r1 = ret.Error(1)
-	}
-
-	return r0, r1
+	return r0
 }
 
 type mockConstructorTestingTNewStore interface {
diff --git a/controller/northbound/server/device.go b/controller/northbound/server/device.go
index c5c07936a..e034192fd 100644
--- a/controller/northbound/server/device.go
+++ b/controller/northbound/server/device.go
@@ -54,10 +54,6 @@ func (d *DeviceServer) GetAll(ctx context.Context, request *dpb.GetAllDeviceRequ
 
 	onds := []*dpb.Device{}
 	for _, device := range devices {
-		device, err := device.ConvertToDevice()
-		if err != nil {
-			return nil, status.Errorf(codes.Aborted, "%v", err)
-		}
 		ygotStructAsJSON, err := device.GetModelAsString()
 		if err != nil {
 			log.Error(err)
diff --git a/controller/northbound/server/pnd.go b/controller/northbound/server/pnd.go
index dc8bd2f87..8020c8b5e 100644
--- a/controller/northbound/server/pnd.go
+++ b/controller/northbound/server/pnd.go
@@ -58,7 +58,7 @@ func (p PndServer) GetOnd(ctx context.Context, request *ppb.GetOndRequest) (*ppb
 		return nil, status.Errorf(codes.Aborted, "%v", err)
 	}
 
-	ond, err := fillOndBySpecificPath(pnd, device, "/")
+	ond, err := fillOndBySpecificPath(device, "/")
 	if err != nil {
 		log.Error(err)
 		return nil, status.Errorf(codes.Aborted, "%v", err)
@@ -92,12 +92,8 @@ func (p PndServer) GetOndList(ctx context.Context, request *ppb.GetOndListReques
 	}
 
 	onds := make([]*ppb.OrchestratedNetworkingDevice, len(pnd.Devices()))
-	for i, loadedDevice := range pnd.Devices() {
-		device, err := loadedDevice.ConvertToDevice()
-		if err != nil {
-			return nil, status.Errorf(codes.Aborted, "%v", err)
-		}
-		ond, err := fillOndBySpecificPath(pnd, device, "/")
+	for i, device := range pnd.Devices() {
+		ond, err := fillOndBySpecificPath(device, "/")
 		if err != nil {
 			log.Error(err)
 			return nil, status.Errorf(codes.Aborted, "%v", err)
@@ -132,7 +128,7 @@ func (p PndServer) GetFlattenedOndList(ctx context.Context, request *ppb.GetOndL
 		return nil, status.Errorf(codes.Aborted, "%v", err)
 	}
 
-	onds := pnd.Devices()
+	onds := pnd.FlattenedDevices()
 	flattenedOnds := make([]*ppb.FlattenedOrchestratedNetworkingDevice, len(onds))
 	for i, ond := range onds {
 		ond := &ppb.FlattenedOrchestratedNetworkingDevice{
@@ -155,7 +151,7 @@ func (p PndServer) GetFlattenedOndList(ctx context.Context, request *ppb.GetOndL
 	}, nil
 }
 
-func fillOndBySpecificPath(pnd networkdomain.NetworkDomain, d device.Device, path string) (*ppb.OrchestratedNetworkingDevice, error) {
+func fillOndBySpecificPath(d device.Device, path string) (*ppb.OrchestratedNetworkingDevice, error) {
 	gnmiPath, err := ygot.StringToStructuredPath(path)
 	if err != nil {
 		log.Error(err)
@@ -354,7 +350,7 @@ func (p PndServer) GetPath(ctx context.Context, request *ppb.GetPathRequest) (*p
 		return nil, status.Errorf(codes.Aborted, "%v", err)
 	}
 
-	ond, err := fillOndBySpecificPath(pnd, device, path)
+	ond, err := fillOndBySpecificPath(device, path)
 	if err != nil {
 		log.Error(err)
 		return nil, status.Errorf(codes.Aborted, "%v", err)
diff --git a/controller/nucleus/deviceService.go b/controller/nucleus/deviceService.go
index 7e2e00c8a..9b3d47252 100644
--- a/controller/nucleus/deviceService.go
+++ b/controller/nucleus/deviceService.go
@@ -49,8 +49,6 @@ func (s *DeviceService) Get(query store.Query) (device.Device, error) {
 		return nil, err
 	}
 
-	loadedDevice.SetConvertFunction(s.createDeviceFromStore)
-
 	device, err := s.createDeviceFromStore(loadedDevice)
 	if err != nil {
 		return nil, err
@@ -60,14 +58,34 @@ func (s *DeviceService) Get(query store.Query) (device.Device, error) {
 }
 
 // GetAll returns all stored devices.
-func (s *DeviceService) GetAll() ([]device.LoadedDevice, error) {
+func (s *DeviceService) GetAll() ([]device.Device, error) {
+	var devices []device.Device
+
 	loadedDevices, err := s.deviceStore.GetAll()
 	if err != nil {
 		return nil, err
 	}
 
 	for _, loadedDevice := range loadedDevices {
-		loadedDevice.SetConvertFunction(s.createDeviceFromStore)
+		device, err := s.createDeviceFromStore(loadedDevice)
+		if err != nil {
+			return nil, err
+		}
+
+		devices = append(devices, device)
+	}
+
+	return devices, nil
+}
+
+// GetAllAsLoaded returns all stored devices as LoadedDevice.
+// This method should be used if there is no need for a device.Device, since
+// requesting device information through this method is a lot faster than the
+// usual `GetAll` method.
+func (s *DeviceService) GetAllAsLoaded() ([]device.LoadedDevice, error) {
+	loadedDevices, err := s.deviceStore.GetAll()
+	if err != nil {
+		return nil, err
 	}
 
 	return loadedDevices, nil
diff --git a/controller/nucleus/deviceServiceMock.go b/controller/nucleus/deviceServiceMock.go
index e90ec1c49..a738d5ae8 100644
--- a/controller/nucleus/deviceServiceMock.go
+++ b/controller/nucleus/deviceServiceMock.go
@@ -89,7 +89,18 @@ func (t *DeviceServiceMock) Get(query store.Query) (device.Device, error) {
 }
 
 // GetAll gets all items.
-func (t *DeviceServiceMock) GetAll() ([]device.LoadedDevice, error) {
+func (t *DeviceServiceMock) GetAll() ([]device.Device, error) {
+	var allItems []device.Device
+
+	for _, item := range t.Store {
+		allItems = append(allItems, item)
+	}
+
+	return allItems, nil
+}
+
+// GetAllAsLoaded gets all items as `device.LoadedDevice`.
+func (t *DeviceServiceMock) GetAllAsLoaded() ([]device.LoadedDevice, error) {
 	var allItems []device.LoadedDevice
 
 	for _, item := range t.Store {
diff --git a/controller/nucleus/deviceWatcher.go b/controller/nucleus/deviceWatcher.go
index b1738331b..61e3f0897 100644
--- a/controller/nucleus/deviceWatcher.go
+++ b/controller/nucleus/deviceWatcher.go
@@ -68,8 +68,7 @@ func (d *DeviceWatcher) SubToDevices(paths [][]string, opts *gnmi.SubscribeOptio
 }
 
 func (d *DeviceWatcher) subscribeToPndDevices(pndID string, pnd networkdomain.NetworkDomain, opts *gnmi.SubscribeOptions) {
-	for _, loadedDevice := range pnd.Devices() {
-		device, _ := loadedDevice.ConvertToDevice()
+	for _, device := range pnd.Devices() {
 		subID := uuid.New()
 
 		stopContext, cancel := context.WithCancel(context.Background())
diff --git a/controller/nucleus/principalNetworkDomain.go b/controller/nucleus/principalNetworkDomain.go
index 86d4f5ee9..fe4b0f9ac 100644
--- a/controller/nucleus/principalNetworkDomain.go
+++ b/controller/nucleus/principalNetworkDomain.go
@@ -157,12 +157,18 @@ func (pnd *pndImplementation) ID() uuid.UUID {
 	return pnd.Id
 }
 
-func (pnd *pndImplementation) Devices() []device.LoadedDevice {
+func (pnd *pndImplementation) Devices() []device.Device {
 	allDevices, _ := pnd.deviceService.GetAll()
 
 	return allDevices
 }
 
+func (pnd *pndImplementation) FlattenedDevices() []device.LoadedDevice {
+	allDevices, _ := pnd.deviceService.GetAllAsLoaded()
+
+	return allDevices
+}
+
 // GetName returns the name of the PND.
 func (pnd *pndImplementation) GetName() string {
 	return pnd.Name
@@ -206,7 +212,7 @@ func (pnd *pndImplementation) AddSbi(s southbound.SouthboundInterface) error {
 func (pnd *pndImplementation) RemoveSbi(sid uuid.UUID) error {
 	var associatedDevices []device.LoadedDevice
 
-	allExistingDevices, err := pnd.deviceService.GetAll()
+	allExistingDevices, err := pnd.deviceService.GetAllAsLoaded()
 	if err != nil {
 		return err
 	}
@@ -436,7 +442,7 @@ func (pnd *pndImplementation) Request(uuid uuid.UUID, path string) (proto.Messag
 
 // RequestAll sends a request for all registered devices.
 func (pnd *pndImplementation) RequestAll(path string) error {
-	allDevices, err := pnd.deviceService.GetAll()
+	allDevices, err := pnd.deviceService.GetAllAsLoaded()
 	if err != nil {
 		return err
 	}
diff --git a/controller/nucleus/principalNetworkDomain_test.go b/controller/nucleus/principalNetworkDomain_test.go
index 8a5d69611..7468c7b92 100644
--- a/controller/nucleus/principalNetworkDomain_test.go
+++ b/controller/nucleus/principalNetworkDomain_test.go
@@ -785,7 +785,7 @@ func Test_pndImplementation_ChangeOND(t *testing.T) {
 				return
 			}
 
-			devices, err := pnd.deviceService.GetAll()
+			devices, err := pnd.deviceService.GetAllAsLoaded()
 			if err != nil {
 				err := errors.New("error fetching device")
 				t.Error(err)
-- 
GitLab