// Copyright (c) 2017 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.

package gnmi

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"math"
	"net"
	"os"

	"io/ioutil"
	"strings"

	"github.com/golang/protobuf/proto"
	pb "github.com/openconfig/gnmi/proto/gnmi"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/encoding/gzip"
	"google.golang.org/grpc/metadata"
)

const (
	// defaultPort is a port number string. NOTE(review): it is not referenced
	// in this file; presumably it is the port assumed when an address omits
	// one — confirm against its users elsewhere in the package.
	defaultPort = "6030"
	// HostnameArg is the value to be replaced by the actual hostname:
	// ParseHostnames substitutes the local hostname for any list entry that
	// is exactly this string.
	HostnameArg = "HOSTNAME"
)

// PublishFunc is the method to publish responses: a callback that receives a
// response message together with an address string.
// NOTE(review): addr presumably identifies the target the message came from;
// no caller is visible in this file, so confirm against its users.
type PublishFunc func(addr string, message proto.Message)

// ParseHostnames parses a comma-separated list of names and replaces every
// entry that is exactly HostnameArg ("HOSTNAME") with the current hostname.
//
// The local hostname is looked up only if the list actually contains
// HostnameArg, so a failing os.Hostname call cannot break lists that never
// needed it. Entries are not trimmed: "a, b" yields "a" and " b".
// An error is returned only when a hostname lookup is required and fails.
func ParseHostnames(list string) ([]string, error) {
	names := strings.Split(list, ",")
	var (
		hostname string
		resolved bool
	)
	for i, name := range names {
		if name != HostnameArg {
			continue
		}
		if !resolved {
			h, err := os.Hostname()
			if err != nil {
				return nil, err
			}
			hostname, resolved = h, true
		}
		names[i] = hostname
	}
	return names, nil
}

// Config is the gnmi.Client config.
//
// It is read by DialContext (connection, compression and TLS settings),
// NewContext (Username/Password metadata) and NewGetRequest (Encoding).
type Config struct {
	// Addr is the target address handed to grpc.DialContext. The context
	// dialer installed by DialContext splits an optional "network://" prefix
	// off the address it receives and dials "tcp" when there is none.
	Addr        string
	// CAFile is a PEM file of CA certificates used to verify the server; a
	// non-empty value enables TLS. NOTE: when TLS is enabled without a
	// CAFile, the server certificate is NOT verified (InsecureSkipVerify).
	CAFile      string
	// CertFile is the client certificate presented over TLS; a non-empty
	// value enables TLS and requires KeyFile to be set as well.
	CertFile    string
	// KeyFile is the private key for CertFile. It is only read when CertFile
	// is set; on its own it neither enables TLS nor is used.
	KeyFile     string
	// Password is sent as "password" metadata by NewContext when Username
	// is non-empty (even if Password itself is empty).
	Password    string
	// Username, when non-empty, makes NewContext attach "username" and
	// "password" metadata for outgoing RPCs.
	Username    string
	// TLS forces a TLS connection even when no CAFile, CertFile or Token is set.
	TLS         bool
	// Compression selects the default call compressor: "" for none or
	// "gzip"; any other value makes DialContext fail.
	Compression string
	// DialOptions are extra gRPC dial options. DialContext copies them before
	// appending its own options, so this slice is not modified.
	DialOptions []grpc.DialOption
	// Token is sent as an "Authorization: Bearer <Token>" header on every
	// RPC; a non-empty value enables TLS.
	Token       string
	// Encoding is the encoding requested in GetRequests built by NewGetRequest.
	Encoding    pb.Encoding
}

// SubscribeOptions is the gNMI subscription request options, as consumed by
// NewSubscribeRequest.
type SubscribeOptions struct {
	// UpdatesOnly is copied to the SubscriptionList's UpdatesOnly field.
	UpdatesOnly       bool
	// Prefix is a path string that is split with SplitPath, parsed with
	// ParseGNMIElements and used as the SubscriptionList prefix.
	Prefix            string
	// Mode is the subscription list mode: "once", "poll" or "stream"
	// ("" means "stream"). Any other value is rejected.
	Mode              string
	// StreamMode is the per-path subscription mode: "on_change", "sample" or
	// "target_defined" ("" means "target_defined"). Any other value is rejected.
	StreamMode        string
	// SampleInterval is copied unchanged to every Subscription.
	// NOTE(review): the gNMI specification expresses it in nanoseconds;
	// nothing here converts units — confirm callers pass nanoseconds.
	SampleInterval    uint64
	// SuppressRedundant is copied unchanged to every Subscription.
	SuppressRedundant bool
	// HeartbeatInterval is copied unchanged to every Subscription
	// (nanoseconds per the gNMI specification; not converted here).
	HeartbeatInterval uint64
	// Paths lists the subscribed paths, each given as its path elements and
	// parsed with ParseGNMIElements.
	Paths             [][]string
	// Origin is set on every subscribed path (not on the prefix).
	Origin            string
	// Target, when non-empty, is set on the prefix, creating an empty prefix
	// if parsing Prefix produced none.
	Target            string
}

// accessTokenCred is a credentials.PerRPCCredentials implementation: gRPC
// asks it for security metadata before every RPC, and it always answers with
// the same bearer-token Authorization header.
type accessTokenCred struct {
	// bearerToken is the complete header value, already prefixed with "Bearer ".
	bearerToken string
}

// Compile-time proof that *accessTokenCred satisfies the gRPC interface.
var _ credentials.PerRPCCredentials = (*accessTokenCred)(nil)

// newAccessTokenCredential wraps token in a per-RPC credential that sends it
// as "Bearer <token>".
func newAccessTokenCredential(token string) credentials.PerRPCCredentials {
	cred := &accessTokenCred{}
	cred.bearerToken = "Bearer " + token
	return cred
}

// GetRequestMetadata returns the Authorization header to attach to an RPC.
// The context and URIs are not consulted because the token never changes.
func (c *accessTokenCred) GetRequestMetadata(_ context.Context,
	_ ...string) (map[string]string, error) {
	md := make(map[string]string, 1)
	md["Authorization"] = c.bearerToken
	return md, nil
}

// RequireTransportSecurity reports true, so gRPC only sends this credential
// over a secure (TLS) transport.
func (c *accessTokenCred) RequireTransportSecurity() bool {
	return true
}

// DialContext connects to a gnmi service and returns a client.
//
// Connection behavior is driven by cfg:
//   - cfg.DialOptions are copied first, so the caller's slice is never modified.
//   - cfg.Compression may be "" (none) or "gzip"; anything else is an error.
//   - TLS is used when cfg.TLS is set or any of cfg.CAFile, cfg.CertFile or
//     cfg.Token is non-empty; otherwise the connection is insecure.
//   - cfg.Token is sent as a bearer token on every RPC.
//   - The context dialer splits an optional "network://" prefix off the
//     address it is given and dials "tcp" when there is none.
//   - Received messages may be up to math.MaxInt32 bytes.
//
// A dial failure is returned wrapped, so errors.Is/errors.As can still
// inspect the underlying gRPC error.
func DialContext(ctx context.Context, cfg *Config) (pb.GNMIClient, error) {
	opts := append([]grpc.DialOption(nil), cfg.DialOptions...)

	switch cfg.Compression {
	case "":
	case "gzip":
		opts = append(opts, grpc.WithDefaultCallOptions(grpc.UseCompressor(gzip.Name)))
	default:
		return nil, fmt.Errorf("unsupported compression option: %q", cfg.Compression)
	}

	if cfg.TLS || cfg.CAFile != "" || cfg.CertFile != "" || cfg.Token != "" {
		tlsConfig, err := newDialTLSConfig(cfg)
		if err != nil {
			return nil, err
		}
		if cfg.Token != "" {
			opts = append(opts,
				grpc.WithPerRPCCredentials(newAccessTokenCredential(cfg.Token)))
		}
		opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
	} else {
		opts = append(opts, grpc.WithInsecure())
	}

	dial := func(ctx context.Context, addrIn string) (net.Conn, error) {
		network, addr := splitDialAddr(addrIn)
		return (&net.Dialer{}).DialContext(ctx, network, addr)
	}

	opts = append(opts,
		grpc.WithContextDialer(dial),

		// Allows received protobuf messages to be larger than 4MB
		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
	)

	grpcconn, err := grpc.DialContext(ctx, cfg.Addr, opts...)
	if err != nil {
		// %w keeps the gRPC error inspectable by callers.
		return nil, fmt.Errorf("failed to dial: %w", err)
	}

	return pb.NewGNMIClient(grpcconn), nil
}

// newDialTLSConfig builds the client TLS configuration for DialContext from
// cfg's CAFile, CertFile and KeyFile settings.
func newDialTLSConfig(cfg *Config) (*tls.Config, error) {
	tlsConfig := &tls.Config{
		MinVersion: tls.VersionTLS12,
	}
	if cfg.CAFile != "" {
		b, err := ioutil.ReadFile(cfg.CAFile)
		if err != nil {
			return nil, err
		}
		cp := x509.NewCertPool()
		if !cp.AppendCertsFromPEM(b) {
			return nil, errors.New("credentials: failed to append certificates")
		}
		tlsConfig.RootCAs = cp
	} else {
		// NOTE(security): without a CA file the server certificate is not
		// verified at all, so the connection (and any bearer token sent over
		// it) is exposed to man-in-the-middle attacks. Kept for compatibility
		// with existing callers; set CAFile to get real verification.
		tlsConfig.InsecureSkipVerify = true
	}
	if cfg.CertFile != "" {
		if cfg.KeyFile == "" {
			return nil, errors.New("please provide both -certfile and -keyfile")
		}
		cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
		if err != nil {
			return nil, err
		}
		tlsConfig.Certificates = []tls.Certificate{cert}
	}
	return tlsConfig, nil
}

// splitDialAddr splits an address of the form "network://address" into its
// network and address parts; addresses without "://" are dialed over "tcp".
// Only the first separator is significant, so the address part may itself
// contain "://".
func splitDialAddr(addrIn string) (network, addr string) {
	parts := strings.SplitN(addrIn, "://", 2)
	if len(parts) == 2 {
		return parts[0], parts[1]
	}
	return "tcp", addrIn
}

// Dial connects to a gnmi service and returns a client. It is DialContext
// called with context.Background(), so the dial cannot be cancelled.
func Dial(cfg *Config) (pb.GNMIClient, error) {
	bg := context.Background()
	return DialContext(bg, cfg)
}

// NewContext returns a new context with username and password
// metadata if they are set in cfg.
//
// When cfg.Username is non-empty, "username" and "password" are appended to
// the outgoing gRPC metadata already carried by ctx. Existing metadata is
// preserved rather than replaced. cfg.Password is sent even if empty. When
// cfg.Username is empty, ctx is returned unchanged.
func NewContext(ctx context.Context, cfg *Config) context.Context {
	if cfg.Username == "" {
		return ctx
	}
	// AppendToOutgoingContext keeps any metadata the caller attached earlier;
	// NewOutgoingContext would silently discard it.
	return metadata.AppendToOutgoingContext(ctx,
		"username", cfg.Username,
		"password", cfg.Password)
}

// NewGetRequest returns a GetRequest for the given paths.
//
// ctx must carry a non-nil *Config under the string key "config". Its
// Encoding field becomes the request encoding. Each element of paths is one
// path given as its elements and parsed with ParseGNMIElements, and origin is
// set on every resulting path. An error is returned if the config value is
// missing, is not a *Config or is a nil *Config, or if any path fails to parse.
func NewGetRequest(ctx context.Context, paths [][]string, origin string) (*pb.GetRequest, error) {
	// NOTE(review): a plain string context key can collide with other
	// packages (staticcheck SA1029). It is kept for compatibility with code
	// that stores the config under this exact key.
	cfg, ok := ctx.Value("config").(*Config)
	// A typed nil *Config passes the assertion but would panic on cfg.Encoding.
	if !ok || cfg == nil {
		return nil, errors.New("invalid type assertion")
	}
	req := &pb.GetRequest{
		Path:     make([]*pb.Path, len(paths)),
		Encoding: cfg.Encoding,
	}
	for i, p := range paths {
		gnmiPath, err := ParseGNMIElements(p)
		if err != nil {
			return nil, err
		}
		gnmiPath.Origin = origin
		req.Path[i] = gnmiPath
	}
	return req, nil
}

// NewSubscribeRequest returns a SubscribeRequest for the given paths.
//
// Mode selects the list mode ("once", "poll" or "stream"; empty means
// "stream"). StreamMode selects the per-path mode ("on_change", "sample" or
// "target_defined"; empty means "target_defined"). Any other value for
// either yields an error.
//
// Prefix is split with SplitPath and parsed as the list prefix. A non-empty
// Target is stored on that prefix, creating an empty prefix if parsing
// produced none. Every entry of Paths becomes one Subscription carrying
// Origin and the interval and redundancy settings.
func NewSubscribeRequest(subscribeOptions *SubscribeOptions) (*pb.SubscribeRequest, error) {
	opts := subscribeOptions

	var listMode pb.SubscriptionList_Mode
	switch opts.Mode {
	case "once":
		listMode = pb.SubscriptionList_ONCE
	case "poll":
		listMode = pb.SubscriptionList_POLL
	case "", "stream":
		listMode = pb.SubscriptionList_STREAM
	default:
		return nil, fmt.Errorf("subscribe mode (%s) invalid", opts.Mode)
	}

	var subMode pb.SubscriptionMode
	switch opts.StreamMode {
	case "on_change":
		subMode = pb.SubscriptionMode_ON_CHANGE
	case "sample":
		subMode = pb.SubscriptionMode_SAMPLE
	case "", "target_defined":
		subMode = pb.SubscriptionMode_TARGET_DEFINED
	default:
		return nil, fmt.Errorf("subscribe stream mode (%s) invalid", opts.StreamMode)
	}

	prefix, err := ParseGNMIElements(SplitPath(opts.Prefix))
	if err != nil {
		return nil, err
	}
	if opts.Target != "" {
		// The target lives on the prefix, so make sure there is one.
		if prefix == nil {
			prefix = &pb.Path{}
		}
		prefix.Target = opts.Target
	}

	subs := make([]*pb.Subscription, 0, len(opts.Paths))
	for _, elems := range opts.Paths {
		p, err := ParseGNMIElements(elems)
		if err != nil {
			return nil, err
		}
		p.Origin = opts.Origin
		subs = append(subs, &pb.Subscription{
			Path:              p,
			Mode:              subMode,
			SampleInterval:    opts.SampleInterval,
			SuppressRedundant: opts.SuppressRedundant,
			HeartbeatInterval: opts.HeartbeatInterval,
		})
	}

	list := &pb.SubscriptionList{
		Subscription: subs,
		Mode:         listMode,
		UpdatesOnly:  opts.UpdatesOnly,
		Prefix:       prefix,
	}
	return &pb.SubscribeRequest{
		Request: &pb.SubscribeRequest_Subscribe{Subscribe: list},
	}, nil
}
