From 5011a5ff8b9662169d2940a2e154d67c2d3028e7 Mon Sep 17 00:00:00 2001
From: Malte Bauch <malte.bauch@stud.h-da.de>
Date: Thu, 19 Oct 2023 09:37:03 +0000
Subject: [PATCH] Resolve "Plugin is created and persisted in the database even
 if the creation of a network element failed"

See merge request danet/gosdn!544
---
 controller/controller.go                      |  2 +-
 controller/interfaces/plugin/plugin.go        |  1 +
 controller/mocks/Plugin.go                    | 14 +++++
 .../server/configurationmanagement.go         |  4 ++
 .../northbound/server/networkElement.go       | 61 ++++++++++++-------
 controller/nucleus/plugin.go                  | 13 +++-
 controller/nucleus/pluginService.go           |  6 +-
 7 files changed, 71 insertions(+), 30 deletions(-)

diff --git a/controller/controller.go b/controller/controller.go
index d91183317..d16fd4bf1 100644
--- a/controller/controller.go
+++ b/controller/controller.go
@@ -392,7 +392,7 @@ func shutdown() error {
 		}
 		for _, plugin := range plugins {
 			log.Info("Defer: ", plugin.Manifest().Name)
-			plugin.GetClient().Kill()
+			plugin.Close()
 			log.Info("Defer - exited: ", plugin.GetClient().Exited())
 		}
 		coreLock.Unlock()
diff --git a/controller/interfaces/plugin/plugin.go b/controller/interfaces/plugin/plugin.go
index f61bda8a8..e51743309 100644
--- a/controller/interfaces/plugin/plugin.go
+++ b/controller/interfaces/plugin/plugin.go
@@ -46,6 +46,7 @@ type Plugin interface {
 	Ping() error
 	Restart() error
 	Close()
+	Remove() error
 	shared.DeviceModel
 }
 
diff --git a/controller/mocks/Plugin.go b/controller/mocks/Plugin.go
index 9af1bf4b7..462e11d89 100644
--- a/controller/mocks/Plugin.go
+++ b/controller/mocks/Plugin.go
@@ -219,6 +219,20 @@ func (_m *Plugin) PruneConfigFalse(value []byte) ([]byte, error) {
 	return r0, r1
 }
 
+// Remove provides a mock function with given fields:
+func (_m *Plugin) Remove() error {
+	ret := _m.Called()
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func() error); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
 // Restart provides a mock function with given fields:
 func (_m *Plugin) Restart() error {
 	ret := _m.Called()
diff --git a/controller/northbound/server/configurationmanagement.go b/controller/northbound/server/configurationmanagement.go
index 7dd74e3cc..ebe8436db 100644
--- a/controller/northbound/server/configurationmanagement.go
+++ b/controller/northbound/server/configurationmanagement.go
@@ -337,6 +337,10 @@ func (c ConfigurationManagementServer) createNetworkElements(sdnConfig *loadedSD
 			return err
 		}
 
+		if err := c.pluginService.Add(plugin); err != nil {
+			return err
+		}
+
 		err = c.mneService.UpdateModel(createdNetworkElement.ID(), inputNetworkElement.Model)
 		if err != nil {
 			return err
diff --git a/controller/northbound/server/networkElement.go b/controller/northbound/server/networkElement.go
index 1528f7514..89407828b 100644
--- a/controller/northbound/server/networkElement.go
+++ b/controller/northbound/server/networkElement.go
@@ -707,25 +707,6 @@ func (n *NetworkElementServer) SetMneList(ctx context.Context, request *mnepb.Se
 func (n *NetworkElementServer) addMne(ctx context.Context, name string, opt *tpb.TransportOption, requestPluginFunc func(uuid.UUID) (plugin.Plugin, error), pluginID uuid.UUID, pndID uuid.UUID, optionalNetworkElementID ...uuid.UUID) (uuid.UUID, error) {
 	var err error
 
-	// Note: cSBI not supported currently, so this is commented fow now.
-	// Might be needed or removed in the future.
-	//
-	// switch t := opt.Type; t {
-	// case spb.Type_TYPE_CONTAINERISED:
-	// 	return n.handleCsbiEnrolment(name, opt)
-	// case spb.Type_TYPE_PLUGIN:
-	// 	sbi, err = n.requestPlugin(name, opt)
-	// 	if err != nil {
-	// 		return uuid.Nil, err
-	// 	}
-	// default:
-	// 	var err error
-	// 	sbi, err = pnd.southboundService.Get(store.Query{ID: sid})
-	// 	if err != nil {
-	// 		return uuid.Nil, err
-	// 	}
-	// }
-
 	networkElementID := uuid.Nil
 	if len(optionalNetworkElementID) > 0 {
 		networkElementID = optionalNetworkElementID[0]
@@ -742,33 +723,67 @@ func (n *NetworkElementServer) addMne(ctx context.Context, name string, opt *tpb
 
 	mne, err := nucleus.NewNetworkElement(name, networkElementID, opt, pndID, plugin, conflict.Metadata{ResourceVersion: 0})
 	if err != nil {
+		if pluginRmErr := plugin.Remove(); err != nil {
+			return uuid.Nil, pluginRmErr
+		}
 		return uuid.Nil, err
 	}
 
 	if mne.IsTransportValid() {
-		resp, err := n.getPath(ctx, mne, "/")
+		err := n.initialNetworkElementRootPathRequest(ctx, mne, plugin)
 		if err != nil {
 			return uuid.Nil, err
 		}
 
-		err = mne.ProcessResponse(resp)
+		err = n.mneService.Add(mne)
 		if err != nil {
+			if pluginRmErr := plugin.Remove(); err != nil {
+				return uuid.Nil, pluginRmErr
+			}
 			return uuid.Nil, err
 		}
 
-		err = n.mneService.Add(mne)
+		err = n.pluginService.Add(plugin)
 		if err != nil {
+			if pluginRmErr := plugin.Remove(); err != nil {
+				return uuid.Nil, pluginRmErr
+			}
+			if err := n.mneService.Delete(mne); err != nil {
+				return uuid.Nil, err
+			}
 			return uuid.Nil, err
 		}
 
 		n.networkElementWatchter.SubscribeToNetworkElement(mne, config.GetGnmiSubscriptionPaths(), nil)
 	} else {
-		return uuid.Nil, status.Errorf(codes.InvalidArgument, "invalid transport data provided")
+		if pluginRmErr := plugin.Remove(); err != nil {
+			return uuid.Nil, pluginRmErr
+		}
+		return uuid.Nil, fmt.Errorf("invalid transport data provided")
 	}
 
 	return mne.ID(), nil
 }
 
+func (n *NetworkElementServer) initialNetworkElementRootPathRequest(ctx context.Context, mne networkelement.NetworkElement, plugin plugin.Plugin) error {
+	resp, err := n.getPath(ctx, mne, "/")
+	if err != nil {
+		if pluginRmErr := plugin.Remove(); err != nil {
+			return pluginRmErr
+		}
+		return err
+	}
+
+	err = mne.ProcessResponse(resp)
+	if err != nil {
+		if pluginRmErr := plugin.Remove(); err != nil {
+			return pluginRmErr
+		}
+		return err
+	}
+	return nil
+}
+
 // SetChangeList sets a list of changes.
 func (n *NetworkElementServer) SetChangeList(ctx context.Context, request *mnepb.SetChangeListRequest) (*mnepb.SetChangeListResponse, error) {
 	labels := prometheus.Labels{"service": "mne", "rpc": "set"}
diff --git a/controller/nucleus/plugin.go b/controller/nucleus/plugin.go
index 032c30ed4..b409e6252 100644
--- a/controller/nucleus/plugin.go
+++ b/controller/nucleus/plugin.go
@@ -3,6 +3,7 @@ package nucleus
 import (
 	"encoding/json"
 	"fmt"
+	"os"
 	"os/exec"
 	"path/filepath"
 
@@ -30,7 +31,7 @@ func NewPlugin(id uuid.UUID, execPath string) (*Plugin, error) {
 	client := hcplugin.NewClient(&hcplugin.ClientConfig{
 		HandshakeConfig:  shared.Handshake,
 		Plugins:          shared.PluginMap,
-		Cmd:              exec.Command("sh", "-c", filepath.Join(execPath, util.PluginExecutableName)),
+		Cmd:              exec.Command(filepath.Join(execPath, util.PluginExecutableName)),
 		AllowedProtocols: []hcplugin.Protocol{hcplugin.ProtocolGRPC},
 	})
 
@@ -129,6 +130,15 @@ func (p *Plugin) ReattachConfig() *hcplugin.ReattachConfig {
 	return p.client.ReattachConfig()
 }
 
+// Remove ensures that the Plugin is killed and the corresponding files are
+// removed.
+func (p *Plugin) Remove() error {
+	// stop the running plugins process
+	p.Close()
+	// remove the plugins folder
+	return os.RemoveAll(p.ExecPath())
+}
+
 // State returns the current state of the plugin.
 // Different states of the plugin can be:
 //   - created
@@ -165,6 +175,7 @@ func (p *Plugin) Restart() error {
 
 // Close ends the execution of the plugin.
 func (p *Plugin) Close() {
+	// end the plugin process
 	p.client.Kill()
 }
 
diff --git a/controller/nucleus/pluginService.go b/controller/nucleus/pluginService.go
index ed7a7ab68..3a1319efd 100644
--- a/controller/nucleus/pluginService.go
+++ b/controller/nucleus/pluginService.go
@@ -106,7 +106,7 @@ func (s *PluginService) Delete(pluginToDelete plugin.Plugin) error {
 	}
 
 	// stop the plugin
-	pluginToDelete.GetClient().Kill()
+	pluginToDelete.Close()
 
 	if err := s.eventService.PublishEvent(PluginEventTopic, event.NewDeleteEvent(pluginToDelete.ID())); err != nil {
 		log.Error(err)
@@ -167,10 +167,6 @@ func (s *PluginService) RequestPlugin(requestID uuid.UUID) (plugin.Plugin, error
 		return nil, err
 	}
 
-	if err := s.Add(plugin); err != nil {
-		return nil, err
-	}
-
 	return plugin, nil
 }
 
-- 
GitLab