Skip to content
Snippets Groups Projects
pluginService.go 6.27 KiB
Newer Older
  • Learn to ignore specific revisions
  • 	"os"
    	"path/filepath"
    	"strings"
    	"time"
    
    	rpb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/plugin-registry"
    	"code.fbi.h-da.de/danet/gosdn/controller/customerrs"
    
    	"code.fbi.h-da.de/danet/gosdn/controller/event"
    	eventInterfaces "code.fbi.h-da.de/danet/gosdn/controller/interfaces/event"
    	"code.fbi.h-da.de/danet/gosdn/controller/interfaces/plugin"
    
    	"code.fbi.h-da.de/danet/gosdn/controller/nucleus/util"
    
    	"code.fbi.h-da.de/danet/gosdn/controller/store"
    	"github.com/google/uuid"
    	hcplugin "github.com/hashicorp/go-plugin"
    	log "github.com/sirupsen/logrus"
    
    	"github.com/spf13/viper"
    	"google.golang.org/grpc"
    
    )
    
    const (
    	// PluginEventTopic is the used topic for plugin related entity changes.
    	PluginEventTopic = "plugin"
    )
    
    // PluginService provides a plugin service implementation.
    type PluginService struct {
    	pluginStore             plugin.Store
    	eventService            eventInterfaces.Service
    	createPluginFromStoreFn func(plugin.LoadedPlugin) (plugin.Plugin, error)
    
    	pluginRegistryClient    rpb.PluginRegistryServiceClient
    
    }
    
    // NewPluginService creates a plugin service.
    
    func NewPluginService(pluginStore plugin.Store, eventService eventInterfaces.Service, createPluginFromStoreFn func(plugin.LoadedPlugin) (plugin.Plugin, error), pluginRegistryClient rpb.PluginRegistryServiceClient) plugin.Service {
    
    	return &PluginService{
    		pluginStore:             pluginStore,
    		eventService:            eventService,
    		createPluginFromStoreFn: createPluginFromStoreFn,
    
    		pluginRegistryClient:    pluginRegistryClient,
    
    	}
    }
    
    // Get takes a Plugin's UUID or name and returns the Plugin.
    func (s *PluginService) Get(query store.Query) (plugin.Plugin, error) {
    	loadedPlugin, err := s.pluginStore.Get(query)
    	if err != nil {
    		return nil, err
    	}
    
    	plugin, err := s.createPluginFromStore(loadedPlugin)
    	if err != nil {
    		return nil, err
    	}
    
    	return plugin, nil
    }
    
    // GetAll returns all stored plugins.
    func (s *PluginService) GetAll() ([]plugin.Plugin, error) {
    	var plugins []plugin.Plugin
    
    	loadedPlugins, err := s.pluginStore.GetAll()
    	if err != nil {
    		return nil, err
    	}
    
    	for _, loadedPlugin := range loadedPlugins {
    		plugin, err := s.createPluginFromStore(loadedPlugin)
    		if err != nil {
    			return nil, err
    		}
    
    		plugins = append(plugins, plugin)
    	}
    
    	return plugins, nil
    }
    
    // Add adds a plugin to the plugin store.
    func (s *PluginService) Add(pluginToAdd plugin.Plugin) error {
    	err := s.pluginStore.Add(pluginToAdd)
    	if err != nil {
    		return err
    	}
    
    	if err := s.eventService.PublishEvent(PluginEventTopic, event.NewAddEvent(pluginToAdd.ID())); err != nil {
    		log.Error(err)
    	}
    
    	return nil
    }
    
    // Delete deletes a plugin from the plugin store.
    func (s *PluginService) Delete(pluginToDelete plugin.Plugin) error {
    	err := s.pluginStore.Delete(pluginToDelete)
    	if err != nil {
    		return err
    	}
    
    	// stop the plugin
    	pluginToDelete.GetClient().Kill()
    
    	if err := s.eventService.PublishEvent(PluginEventTopic, event.NewDeleteEvent(pluginToDelete.ID())); err != nil {
    		log.Error(err)
    	}
    
    	return nil
    }
    
    func (s *PluginService) createPluginFromStore(loadedPlugin plugin.LoadedPlugin) (plugin.Plugin, error) {
    	plugin, err := s.createPluginFromStoreFn(loadedPlugin)
    	if err != nil {
    		if errors.Is(err, hcplugin.ErrProcessNotFound) {
    
    			plugin, err = NewPlugin(uuid.MustParse(loadedPlugin.ID), loadedPlugin.ExecPath)
    
    			if err != nil {
    				return nil, err
    			}
    			err := s.pluginStore.Update(plugin)
    			if err != nil {
    				return nil, err
    			}
    		} else {
    			return nil, err
    		}
    	}
    
    	return plugin, nil
    }
    
    
    // RequestPlugin request a plugin from the plugin-registry.
    func (s *PluginService) RequestPlugin(requestID uuid.UUID) (plugin.Plugin, error) {
    	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*1)
    	defer cancel()
    
    	pluginDownloadRequest := &rpb.GetDownloadRequest{
    		Timestamp: time.Now().UnixNano(),
    		Id:        requestID.String(),
    	}
    
    
    	folderName := viper.GetString("plugin-folder")
    	path := filepath.Join(folderName, requestID.String())
    	if _, err := os.Stat(filepath.Join(path, util.PluginExecutableName)); errors.Is(err, fs.ErrNotExist) {
    		dClient, err := s.pluginRegistryClient.Download(ctx, pluginDownloadRequest)
    		if err != nil {
    			return nil, err
    		}
    
    		if err := saveStreamToFile(dClient, util.BundledPluginName, requestID); err != nil {
    			return nil, err
    		}
    
    		if err := util.UnzipPlugin(requestID); err != nil {
    			return nil, err
    		}
    
    	plugin, err := NewPlugin(uuid.New(), path)
    
    	if err != nil {
    		return nil, err
    	}
    
    	if err := s.Add(plugin); err != nil {
    		return nil, err
    	}
    
    	return plugin, nil
    }
    
    // StreamClient allows to distinguish between the different ygot
    // generated GoStruct clients, which return a stream of bytes.
    type StreamClient interface {
    	Recv() (*rpb.GetDownloadPayload, error)
    	grpc.ClientStream
    }
    
    // saveStreamToFile takes a StreamClient and processes the included gRPC
    // stream. A file with the provided filename is created within the goSDN's
    // 'plugin-folder'. Each file is stored in its own folder based on a new
    // uuid.UUID.
    func saveStreamToFile(sc StreamClient, filename string, id uuid.UUID) (err error) {
    	folderName := viper.GetString("plugin-folder")
    	path := filepath.Join(folderName, id.String(), filename)
    
    	// clean path to prevent attackers to get access to to directories elsewhere on the system
    	path = filepath.Clean(path)
    	if !strings.HasPrefix(path, folderName) {
    		return &customerrs.InvalidParametersError{
    			Func:  saveStreamToFile,
    			Param: path,
    		}
    	}
    
    	// create the directory hierarchy based on the path
    	if err := os.MkdirAll(filepath.Dir(path), 0770); err != nil {
    		return err
    	}
    	// create the gostructs.go file at path
    	f, err := os.Create(path)
    	if err != nil {
    		return err
    	}
    
    	defer func() {
    		if ferr := f.Close(); ferr != nil {
    			fErrString := ferr.Error()
    			err = fmt.Errorf("InternalError=%w error closing file:%+s", err, fErrString)
    		}
    	}()
    
    	// receive byte stream
    	for {
    		payload, err := sc.Recv()
    		if err != nil {
    			if errors.Is(err, io.EOF) {
    				break
    			}
    			closeErr := sc.CloseSend()
    			if closeErr != nil {
    				return closeErr
    			}
    
    			return err
    		}
    		n, err := f.Write(payload.Chunk)
    		if err != nil {
    			closeErr := sc.CloseSend()
    			if closeErr != nil {
    				return closeErr
    			}
    
    			return err
    		}
    		log.WithField("n", n).Trace("wrote bytes")
    	}
    	if err := f.Sync(); err != nil {
    		return err
    	}
    
    	return nil
    }