Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
cli-handling.go 8.03 KiB
/*
	This file contains the server-side handlers of the gRPC CLI.
	Functions here should delegate to the functions in charge of the
	particular task.
*/

package nucleus

import (
	"context"
	"errors"
	"io"
	"net"
	"os"
	"sync"

	"github.com/google/uuid"
	"github.com/spf13/viper"

	pb "code.fbi.h-da.de/cocsn/gosdn/api/proto"
	"code.fbi.h-da.de/cocsn/gosdn/forks/goarista/gnmi"
	gpb "github.com/openconfig/gnmi/proto/gnmi"
	log "github.com/sirupsen/logrus"
	"google.golang.org/grpc"
	"google.golang.org/grpc/health"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
	"google.golang.org/protobuf/types/known/emptypb"
)

// logConnection tracks one CLI client that subscribed to goSDN log output via
// CreateLogStream.
type logConnection struct {
	// stream is the server side of the client's log stream; BroadcastLog
	// sends log lines on it.
	stream pb.GrpcCli_CreateLogStreamServer
	// id identifies the connection. NOTE(review): it is neither set nor read
	// in this file.
	id     string
	// active is cleared once a Send on stream has failed; BroadcastLog skips
	// inactive connections.
	active bool
	// error hands the Send failure to the CreateLogStream handler, which
	// returns it to end the stream.
	error  chan error
}

// server implements the gRPC CLI service; the RPC handlers in this file
// (SayHello, CreatePND, AddDevice, ...) are its methods.
type server struct {
	// Embedded generated stub; by protoc-gen-go-grpc convention it supplies
	// "unimplemented" handlers for RPCs that server does not define itself.
	pb.UnimplementedGrpcCliServer
	// core is the goSDN core whose stores (pndc, sbic) and IsRunning channel
	// the handlers use.
	core           *Core
	// logConnections holds the clients registered through CreateLogStream;
	// BroadcastLog sends every log line to the active ones.
	logConnections []*logConnection
}

// srv is the server instance assigned by getCLIGoing. buf.Write uses it to
// forward log output to the CLI log streams.
var srv *server

type buf []byte

func (b *buf) Write(p []byte) (n int, err error) {

	reply := pb.LogReply{Log: string(p)}
	srv.BroadcastLog(&reply)
	return len(p), nil
}

// SayHello answers a CLI greeting: it logs the name sent by the client and
// echoes it back together with a static goSDN version string.
func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
	name := in.GetName()
	log.Info("Received: ", name)

	reply := &pb.HelloReply{
		Message:   "Hello " + name,
		GoSDNInfo: "goSDN in version: DEVELOP",
	}
	return reply, nil
}

// logConnectionsMu guards server.logConnections and the active flag of every
// logConnection in it. Log lines may be written from any goroutine, while
// CreateLogStream registers new streams from gRPC handler goroutines. A gRPC
// stream must also not be used by several concurrent Send calls, so every
// access goes through this lock.
//
// Never log while holding it: log output is routed back into BroadcastLog
// (via buf.Write), which would try to take the lock again and deadlock.
var logConnectionsMu sync.Mutex

// CreateLogStream creates a continuous stream between client and server to
// send goSDN logs.
//
// The stream is registered with the server and then kept open until either a
// Send on it fails (that error is returned) or the client's stream context is
// done (the context error is returned). While registered, the stream receives
// every log line passed to BroadcastLog.
func (s *server) CreateLogStream(req *emptypb.Empty, stream pb.GrpcCli_CreateLogStreamServer) error {
	// The error channel is buffered so a broadcaster reporting a Send failure
	// never blocks, even if this handler already stopped waiting for it.
	conn := &logConnection{
		stream: stream,
		active: true,
		error:  make(chan error, 1),
	}

	logConnectionsMu.Lock()
	s.logConnections = append(s.logConnections, conn)
	logConnectionsMu.Unlock()

	select {
	case err := <-conn.error:
		return err
	case <-stream.Context().Done():
		// The client went away without a failed Send. Deactivate under the
		// lock so no broadcast touches the stream after this handler returns;
		// BroadcastLog prunes it from the list on its next run.
		logConnectionsMu.Lock()
		conn.active = false
		logConnectionsMu.Unlock()
		return stream.Context().Err()
	}
}

// BroadcastLog sends log to every active stream registered through
// CreateLogStream and returns once all sends have finished.
//
// A stream whose Send fails is deactivated, and the error is handed to its
// CreateLogStream handler, which returns it. Inactive connections are removed
// from s.logConnections before BroadcastLog returns.
//
// The parameter name shadows the log package inside this method, which also
// rules out accidentally logging (and so re-entering BroadcastLog) here.
func (s *server) BroadcastLog(log *pb.LogReply) {
	logConnectionsMu.Lock()
	defer logConnectionsMu.Unlock()

	var wg sync.WaitGroup
	for _, conn := range s.logConnections {
		if !conn.active {
			continue
		}
		wg.Add(1)
		// Sends to different streams run in parallel; each goroutine touches
		// only its own conn, and the lock held above keeps a single Send per
		// stream at a time.
		go func(conn *logConnection) {
			defer wg.Done()
			if err := conn.stream.Send(log); err != nil {
				conn.active = false
				// Buffered (cap 1) by CreateLogStream, and only sent once,
				// because the conn is inactive from now on.
				conn.error <- err
			}
		}(conn)
	}
	wg.Wait()

	// Drop inactive connections in place and clear the vacated tail so their
	// streams can be garbage collected.
	kept := s.logConnections[:0]
	for _, conn := range s.logConnections {
		if conn.active {
			kept = append(kept, conn)
		}
	}
	for i := len(kept); i < len(s.logConnections); i++ {
		s.logConnections[i] = nil
	}
	s.logConnections = kept
}

// Shutdown handles a CLI shutdown request. It logs the requester's name,
// signals the core by sending false on core.IsRunning, and acknowledges the
// request.
//
// NOTE(review): the send on IsRunning may block unless that channel is
// buffered or a receiver is waiting; this is not visible here.
func (s *server) Shutdown(ctx context.Context, in *pb.ShutdownRequest) (*pb.ShutdownReply, error) {
	name := in.GetName()
	log.Info("Shutdown Received: ", name)

	s.core.IsRunning <- false

	reply := &pb.ShutdownReply{Message: "Shutdown " + name}
	return reply, nil
}

// getCLIGoing starts the gRPC control server used by the CLI and blocks while
// serving it.
//
// It listens on TCP at the address from the "socket" config key. It registers
// the gRPC health service (reporting SERVING) and the CLI service backed by
// core, and tees all log output to stdout and to the CLI log streams. Failing
// to listen or to serve is fatal.
func getCLIGoing(core *Core) {
	log.Info("Starting: GetCLIGoing")

	// Boot-up the control interface for the cli.
	listener, err := net.Listen("tcp", viper.GetString("socket"))
	if err != nil {
		log.Fatal(err)
	}

	grpcServer := grpc.NewServer()
	healthServer := health.NewServer()

	// srv must be assigned before log output is redirected below, because
	// buf.Write forwards every log line to srv.BroadcastLog.
	srv = &server{core: core}

	//TODO: move?
	log.SetOutput(io.MultiWriter(os.Stdout, new(buf)))

	healthpb.RegisterHealthServer(grpcServer, healthServer)
	pb.RegisterGrpcCliServer(grpcServer, srv)

	// The empty service name reports the status of the server as a whole.
	healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING)

	if err := grpcServer.Serve(listener); err != nil {
		log.Fatal(err)
	}
}

// CreatePND creates a new PND with the requested name and description and
// adds it to the principalNetworkDomain store of the core.
//
// The new PND is bound to the SBI that s.core.sbic.UUIDs() lists first.
// NOTE(review): with several registered SBIs, which one that is depends on
// the ordering of UUIDs(); an explicit SBI in the request would be clearer.
//
// It returns an error if no SBI is registered, the SBI lookup fails, the PND
// cannot be constructed (also reported in the reply's Message), or the PND
// cannot be added to the store.
func (s *server) CreatePND(ctx context.Context, in *pb.CreatePNDRequest) (*pb.CreatePNDReply, error) {
	// Trailing space needed: logrus joins two string operands without one.
	log.Info("Received: Create a PND with the name ", in.GetName())

	// Look the SBI up by an ID that is actually registered; a freshly
	// generated UUID can never match a stored entry.
	sbiIDs := s.core.sbic.UUIDs()
	if len(sbiIDs) == 0 {
		err := errors.New("cannot create PND: no southbound interface registered")
		log.Error(err)
		return nil, err
	}
	sbi, err := s.core.sbic.get(sbiIDs[0])
	if err != nil {
		log.Error(err)
		return nil, err
	}

	id := uuid.New()
	pnd, err := NewPNDwithId(in.GetName(), in.GetDescription(), id, sbi.(SouthboundInterface))
	if err != nil {
		log.Error(err)
		return &pb.CreatePNDReply{Message: err.Error()}, err
	}
	if err := s.core.pndc.add(pnd); err != nil {
		log.Error(err)
		return nil, err
	}

	return &pb.CreatePNDReply{Message: "Created new PND: " + id.String()}, nil
}

// GetAllPNDs returns a slim view of every PND registered with the core. Each
// entry holds the PND's UUID, name and description, the identifier of its
// first SBI, and its devices. NOTE(review): device passwords are included in
// the reply.
//
// A PND or device that cannot be read consistently is logged and left out of
// the reply instead of aborting the request (or panicking). That covers
// lookup errors, unexpected implementation types, and a PND without any SBI.
//
// Deprecated: subject to change; this relies on the discontinued direct
// access to a PND's full device store.
func (s *server) GetAllPNDs(ctx context.Context, in *emptypb.Empty) (*pb.AllPNDsReply, error) {
	log.Info("Received: Get all PNDs")
	var pnds []*pb.PND
	for _, uuidPND := range s.core.pndc.UUIDs() {
		pnd, err := s.core.pndc.get(uuidPND)
		if err != nil {
			log.Error(err)
			continue
		}
		impl, ok := pnd.(*pndImplementation)
		if !ok {
			log.Errorf("PND %v: unexpected implementation type %T", uuidPND, pnd)
			continue
		}

		var devices []*pb.Device
		for uuidDevice, d := range impl.devices.store {
			device, ok := d.(*Device)
			if !ok {
				log.Error(&ErrInvalidTypeAssertion{
					v: d,
					t: "Device",
				})
				// Skip it: device is nil here and reading its Config would
				// panic.
				continue
			}
			devices = append(devices, &pb.Device{
				Uuid:     uuidDevice.String(),
				Address:  device.Config.Address,
				Username: device.Config.Username,
				Password: device.Config.Password,
			})
		}

		sbis, ok := pnd.GetSBIs().(*sbiStore)
		if !ok {
			log.Errorf("PND %v: unexpected SBI store type %T", uuidPND, pnd.GetSBIs())
			continue
		}
		sbiIDs := sbis.UUIDs()
		if len(sbiIDs) == 0 {
			log.Errorf("PND %v has no southbound interface", uuidPND)
			continue
		}
		sbi, err := s.core.sbic.get(sbiIDs[0])
		if err != nil {
			log.Error(err)
			continue
		}

		pnds = append(pnds, &pb.PND{
			Uuid:        uuidPND.String(),
			Name:        pnd.GetName(),
			Description: pnd.GetDescription(),
			Sbi:         sbi.SbiIdentifier(),
			Devices:     devices,
		})
	}
	return &pb.AllPNDsReply{Pnds: pnds}, nil
}

// GetAllSBINames lists the identifiers of all SBIs registered with the core.
// An SBI that cannot be fetched from the store is logged and left out.
func (s *server) GetAllSBINames(ctx context.Context, in *emptypb.Empty) (*pb.AllSBINamesReply, error) {
	reply := &pb.AllSBINamesReply{}
	for _, sbiID := range s.core.sbic.UUIDs() {
		sbi, err := s.core.sbic.get(sbiID)
		if err != nil {
			log.Error(err)
			continue
		}
		reply.SbiNames = append(reply.SbiNames, sbi.SbiIdentifier())
	}
	return reply, nil
}

// AddDevice adds a new device to the PND identified by in.UuidPND.
// Currently this only works with gNMI transports.
//
// The device is reached through a gNMI transport using JSON_IETF encoding,
// with the address and credentials from in.Device. It is bound to the SBI
// that the PND's SBI store lists first. NOTE(review): with several SBIs that
// choice depends on the ordering of UUIDs().
//
// Every failure is logged and returned both as the gRPC error and as the
// reply's Message. This includes a malformed PND UUID, a missing device, and
// a PND without an SBI.
func (s *server) AddDevice(ctx context.Context, in *pb.AddDeviceRequest) (*pb.AddDeviceReply, error) {
	log.Info("Received: AddDevice")
	fail := func(err error) (*pb.AddDeviceReply, error) {
		log.Error(err)
		return &pb.AddDeviceReply{Message: err.Error()}, err
	}

	uuidPND, err := uuid.Parse(in.UuidPND)
	if err != nil {
		return fail(err)
	}
	// Without this check a request lacking a device would panic below.
	if in.Device == nil {
		return fail(errors.New("cannot add device: request contains no device"))
	}
	pnd, err := s.core.pndc.get(uuidPND)
	if err != nil {
		return fail(err)
	}

	// TODO: Add notion of default SBI to PND or solve differently
	sbis, ok := pnd.GetSBIs().(*sbiStore)
	if !ok {
		return fail(errors.New("cannot add device: unexpected SBI store type for PND " + uuidPND.String()))
	}
	sbiIDs := sbis.UUIDs()
	if len(sbiIDs) == 0 {
		return fail(errors.New("cannot add device: no southbound interface registered for PND " + uuidPND.String()))
	}
	sbi, err := s.core.sbic.get(sbiIDs[0])
	if err != nil {
		return fail(err)
	}

	//TODO: could the transport and the related config be created in device?
	transport := &Gnmi{SetNode: sbi.SetNode()}
	cfg := &gnmi.Config{
		Addr:     in.Device.Address,
		Username: in.Device.Username,
		Password: in.Device.Password,
		Encoding: gpb.Encoding_JSON_IETF,
	}
	transport.SetConfig(cfg)

	newDevice := NewDevice(sbi, in.Device.Address, in.Device.Username,
		in.Device.Password, transport)

	if err := pnd.AddDevice(newDevice); err != nil {
		return fail(err)
	}

	return &pb.AddDeviceReply{Message: "Added new Device: " + newDevice.Config.Uuid.String()}, nil
}

// HandleDeviceGetRequest handles a GET request via pnd.Request(). It fetches
// in.GetPath() from the device identified by in.GetUuidDevice() inside the
// PND identified by in.GetUuidPND(), then replies with the device as
// marshalled by pnd.MarshalDevice.
//
// Every failure is logged at Info level and returned both as the gRPC error
// and as the reply's Message.
func (s *server) HandleDeviceGetRequest(ctx context.Context, in *pb.DeviceGetRequest) (*pb.DeviceGetReply, error) {
	log.Info("Received: HandleDeviceGetRequest")

	reject := func(err error) (*pb.DeviceGetReply, error) {
		log.Info(err)
		return &pb.DeviceGetReply{Message: err.Error()}, err
	}

	pndID, err := uuid.Parse(in.GetUuidPND())
	if err != nil {
		return reject(err)
	}
	deviceID, err := uuid.Parse(in.GetUuidDevice())
	if err != nil {
		return reject(err)
	}

	// The PND must exist and contain the device before anything is requested.
	pnd, err := s.core.pndc.get(pndID)
	switch {
	case err != nil:
		return reject(errors.New("Couldnt find PND: UUID is wrong"))
	case !pnd.ContainsDevice(deviceID):
		return reject(errors.New("Couldnt find device: UUID is wrong"))
	}

	// GET request for the provided path.
	if err := pnd.Request(deviceID, in.GetPath()); err != nil {
		return reject(err)
	}

	marshalled, err := pnd.MarshalDevice(deviceID)
	if err != nil {
		return reject(err)
	}
	return &pb.DeviceGetReply{Message: marshalled}, nil
}