Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
config.go 5.06 KiB
package config

import (
	"os"
	"time"

	"github.com/google/uuid"
	log "github.com/sirupsen/logrus"
	"github.com/spf13/viper"
)

// Names of the settings read from the viper configuration.
const (
	basePNDUUIDKey               = "basePNDUUID"
	databaseConnectionKey        = "databaseConnection"
	filesystemPathToStores       = "filesystemPathToStores"
	gNMISubscriptionsFilePathKey = "gNMISubscriptionsPath"
	jwtDurationKey               = "defaultJWTDuration"
	jwtSecretKey                 = "jwtSecret"
	maxTokensPerUserKey          = "maxTokensPerUser"

	// RabbitMQ broker connection settings.
	amqpPrefixKey   = "amqpPrefix"
	amqpUserKey     = "amqpUser"
	amqpPasswordKey = "amqpPassword"
	amqpHostKey     = "amqpHost"
	amqpPortKey     = "amqpPort"

	// TLS material locations.
	tlsCACertFileKey = "tlsCACertFile"
	tlsCertFileKey   = "tlsCertFile"
	tlsKeyFileKey    = "tlsKeyFile"
)

// changeTimeoutKey names the environment variable (not a viper key) that
// overrides the change timeout.
const changeTimeoutKey = "GOSDN_CHANGE_TIMEOUT"

// Default values for settings that may be left unconfigured.
const (
	defaultTimeOutDuration10minutes = 10 * time.Minute
	defaultJWTDuration              = 24 * time.Hour
	defaultMaxTokensPerUserDefault  = 100
)

var (
	// BasePndUUID is the UUID of the base PND. InitializeConfig loads it from
	// the configuration and generates (and stores) a new one when none is
	// configured.
	BasePndUUID uuid.UUID

	// ChangeTimeout is the default timeout for a change. InitializeConfig
	// sets it from the GOSDN_CHANGE_TIMEOUT environment variable and falls
	// back to ten minutes when the variable is unset.
	ChangeTimeout time.Duration

	// LogLevel is the default log level: log.PanicLevel when the GOSDN_LOG
	// environment variable is "nolog", log.InfoLevel otherwise.
	LogLevel log.Level
)

var (
	// DatabaseConnection holds the credentials and address of the database
	// in use. An empty value means no database is configured; see
	// UseDatabase.
	DatabaseConnection string

	// FilesystemPathToStores is the folder in which the stores are saved.
	// It starts out as "stores_testing"; InitializeConfig overwrites it with
	// the configured value, or "stores" when none is set.
	FilesystemPathToStores = "stores_testing"

	// JWTDuration determines how long a JWT is valid.
	JWTDuration time.Duration

	// JWTSecret is the secret used to sign tokens.
	JWTSecret string

	// AMQPPrefix is the AMQP prefix for the RabbitMQ broker.
	AMQPPrefix string

	// AMQPUser is the AMQP user.
	AMQPUser string

	// AMQPPassword is the AMQP user's password.
	AMQPPassword string

	// AMQPHost is the AMQP host.
	AMQPHost string

	// AMQPPort is the AMQP port.
	AMQPPort string

	// GNMISubscriptionsFilePath is the path to the file used for automated
	// gNMI subscriptions.
	GNMISubscriptionsFilePath string

	// CAFilePath is the path to the root CA certificate file.
	CAFilePath string

	// CertFilePath is the path to the signed certificate that the controller
	// uses for TLS connections.
	CertFilePath string

	// KeyFilePath is the path to the private key that the controller uses
	// for TLS connections.
	KeyFilePath string

	// MaxTokensPerUser is the maximum number of tokens a user can hold, which
	// bounds the number of concurrent logged-in sessions per user.
	MaxTokensPerUser int
)

// Init loads the configuration by calling InitializeConfig. Any error is
// logged rather than returned; callers that need to react to a failed load
// should call InitializeConfig directly.
//
// Despite its name, this is an ordinary exported function, not a Go init
// function. It runs only when called explicitly.
func Init() {
	if err := InitializeConfig(); err != nil {
		// log.Error joins its arguments with fmt.Sprint semantics, which puts
		// no separator between a string and the error, so format explicitly.
		log.Errorf("failed initialization of module import: %v", err)
	}
}

// InitializeConfig loads the controller configuration from viper and the
// environment into this package's variables. It then persists the
// configuration with viper.WriteConfig, which includes a newly generated base
// PND UUID when none was configured.
//
// Fallbacks: FilesystemPathToStores defaults to "stores", JWTDuration to
// defaultJWTDuration and MaxTokensPerUser to defaultMaxTokensPerUserDefault.
// An invalid JWT duration is not fatal; it is logged and the default is used.
//
// It returns the first error from reading the base PND UUID, setting the
// change timeout, or writing the configuration back.
func InitializeConfig() error {
	basePNDUUIDFromViper, err := getUUIDFromViper(basePNDUUIDKey)
	if err != nil {
		return err
	}
	BasePndUUID = basePNDUUIDFromViper

	if err := setChangeTimeout(); err != nil {
		return err
	}

	setLogLevel()

	DatabaseConnection = viper.GetString(databaseConnectionKey)

	FilesystemPathToStores = viper.GetString(filesystemPathToStores)
	if FilesystemPathToStores == "" {
		FilesystemPathToStores = "stores"
	}

	// The configured value is a bare number interpreted as hours. An unset
	// key is expected and quietly uses the default. A value that is set but
	// unparsable (e.g. "24h", which becomes "24hh") is reported so the
	// misconfiguration is visible instead of silently ignored.
	JWTDuration, err = getDurationFromViper(jwtDurationKey, "h")
	if err != nil {
		if raw := viper.GetString(jwtDurationKey); raw != "" {
			log.Warnf("invalid %s %q, using default %v: %v", jwtDurationKey, raw, defaultJWTDuration, err)
		}
		JWTDuration = defaultJWTDuration
	}

	JWTSecret = viper.GetString(jwtSecretKey)

	MaxTokensPerUser = viper.GetInt(maxTokensPerUserKey)
	if MaxTokensPerUser == 0 {
		MaxTokensPerUser = defaultMaxTokensPerUserDefault
	}

	GNMISubscriptionsFilePath = viper.GetString(gNMISubscriptionsFilePathKey)

	loadAMQPConfig()

	CAFilePath = viper.GetString(tlsCACertFileKey)
	CertFilePath = viper.GetString(tlsCertFileKey)
	KeyFilePath = viper.GetString(tlsKeyFileKey)

	return viper.WriteConfig()
}

// UseDatabase reports whether a database connection string is configured,
// so other modules can decide whether to use a database as their backend.
func UseDatabase() bool {
	return DatabaseConnection != ""
}

// getUUIDFromViper returns the UUID stored in viper under viperKey.
//
// If the key holds an empty string (or is unset), a fresh random UUID is
// generated, stored back into viper under the same key, and returned. A
// non-empty value that does not parse as a UUID yields uuid.Nil together
// with the parse error.
func getUUIDFromViper(viperKey string) (uuid.UUID, error) {
	raw := viper.GetString(viperKey)
	if raw != "" {
		id, parseErr := uuid.Parse(raw)
		if parseErr != nil {
			// Return uuid.Nil explicitly instead of whatever value Parse produced.
			return uuid.Nil, parseErr
		}
		return id, nil
	}

	generated := uuid.New()
	viper.Set(viperKey, generated.String())

	return generated, nil
}

// setChangeTimeout sets ChangeTimeout from the GOSDN_CHANGE_TIMEOUT
// environment variable (see changeTimeoutKey). The value must use
// time.ParseDuration syntax, e.g. "30s" or "5m". When the variable is unset
// or empty, defaultTimeOutDuration10minutes is used.
//
// An unparsable value is returned as an error and ChangeTimeout is left
// unchanged. Previously the process was terminated via log.Fatal, which made
// the declared error return dead code.
func setChangeTimeout() error {
	raw := os.Getenv(changeTimeoutKey)
	if raw == "" {
		ChangeTimeout = defaultTimeOutDuration10minutes
		return nil
	}

	timeout, err := time.ParseDuration(raw)
	if err != nil {
		return err
	}
	ChangeTimeout = timeout

	return nil
}

func setLogLevel() {
	if os.Getenv("GOSDN_LOG") == "nolog" {
		LogLevel = log.PanicLevel
	} else {
		LogLevel = log.InfoLevel
	}
}

// getDurationFromViper reads viperKey as a string, appends unit (for example
// "h") and parses the result with time.ParseDuration. The stored value is
// therefore expected to be a bare number such as "24". With unit "h", an
// unset key ("h") or a value that already carries a unit ("24h" -> "24hh")
// fails to parse. On failure it returns 0 and the parse error.
func getDurationFromViper(viperKey, unit string) (time.Duration, error) {
	withUnit := viper.GetString(viperKey) + unit

	parsed, parseErr := time.ParseDuration(withUnit)
	if parseErr != nil {
		return 0, parseErr
	}

	return parsed, nil
}

func loadAMQPConfig() {
	AMQPPrefix = viper.GetString(amqpPrefixKey)
	AMQPUser = viper.GetString(amqpUserKey)
	AMQPPassword = viper.GetString(amqpPasswordKey)
	AMQPHost = viper.GetString(amqpHostKey)
	AMQPPort = viper.GetString(amqpPortKey)
}