package webhooks

import (
	"bytes"
	"crypto/hmac"
	"crypto/sha512"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"

	"go.uber.org/zap"
)

// maxBodySize caps the size of an incoming webhook request body at 2 MiB
// (2 * 1024 * 1024 bytes); ServeHTTP rejects anything larger.
const maxBodySize = 2 * 1024 * 1024

// Message describes a received web hook as decoded from the JSON request
// body.
//
// Event names the kind of hook and is required by the receiver. Timestamp is
// taken from the "ts" field; its unit (presumably Unix seconds) is not
// established here — TODO confirm against the BBBAtScale webhook
// documentation. Payload holds the event-specific data as a generic JSON
// object, so nested values have the types encoding/json produces for
// interface{} (for example, JSON numbers decode as float64). Which keys it
// contains presumably varies with Event.
type Message struct {
	Event     string                 `json:"event"`
	Timestamp int64                  `json:"ts"`
	Payload   map[string]interface{} `json:"payload"`
}

// Receive builds a webhook receiver configured by opts. It returns the
// channel on which decoded messages from BBBAtScale are delivered together
// with the http.Handler that accepts the incoming hooks. The error result is
// currently always nil.
//
// Options are applied in order; afterwards any logger or channel left unset
// is defaulted (a no-op logger and an unbuffered channel). Because the
// handler passes each accepted message over the channel synchronously,
// callers must keep draining it.
func Receive(opts ...ReceiverOption) (<-chan Message, http.Handler, error) {
	wr := &receiver{}
	for i := range opts {
		opts[i](wr)
	}

	if wr.log == nil {
		wr.log = zap.NewNop()
	}
	if wr.ch == nil {
		wr.ch = make(chan Message)
	}

	return wr.ch, wr, nil
}

// A ReceiverOption can be passed to Receive to customize behaviour of a
// webhook receiver. Receive applies the options in the order given to the
// receiver under construction, and only afterwards fills in defaults for
// the channel and logger if no option set them.
type ReceiverOption func(*receiver)

// Authenticate enables authentication of incoming webhooks using HMAC-SHA512
// and the provided key.
//
// The key is copied when Authenticate is called, so later modifications of
// the caller's slice (reuse, zeroing) do not affect the receiver. An empty or
// nil key leaves authentication disabled, exactly as if this option had not
// been given.
func Authenticate(key []byte) ReceiverOption {
	k := append([]byte(nil), key...)
	return func(r *receiver) {
		r.macKey = k
	}
}

// WithLogger makes the receiver write its log output to l. Passing nil has
// the same effect as omitting the option, because Receive substitutes a
// no-op logger for a nil one.
func WithLogger(l *zap.Logger) ReceiverOption {
	setLogger := func(wr *receiver) { wr.log = l }
	return setLogger
}

// The receiver is a http.Handler receiving web hooks from BBBatScale.
//
// Receive constructs it, applies the ReceiverOptions, and then fills in
// defaults: ch becomes an unbuffered channel when nil, and log becomes a
// no-op logger when nil. ServeHTTP sends every accepted Message on ch and
// writes its diagnostics to log.
type receiver struct {
	ch  chan Message
	log *zap.Logger

	// Used for authenticating incoming requests. The X-Hook-Signature
	// header carries a hex-encoded tag (v1) and a timestamp (t). A
	// SHA512-HMAC keyed with macKey and computed over the decimal t, a
	// dot, and the raw request body ("<t>.<body>") must match that tag.
	// An empty key disables authentication.
	macKey []byte
}

// ServeHTTP handles a single webhook delivery.
//
// Method handling:
//   - Only POST requests are processed.
//   - OPTIONS requests get an empty response carrying an "Allow: POST"
//     header.
//   - Every other method is answered with 405 Method Not Allowed.
//
// For POST, the body (at most maxBodySize bytes, 400 otherwise) is read and
// authenticated by verifyTag against the X-Hook-Signature header (403 on
// failure). It is then decoded as a Message, which must name an Event (400
// otherwise).
//
// The accepted message is handed to the consumer via wr.ch. The handler
// waits for that hand-over, but gives up with 503 if the request context is
// cancelled first (for example because the client disconnected). This keeps
// a stalled consumer from pinning handler goroutines indefinitely.
func (wr *receiver) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	log := wr.log.Sugar().With(
		"method", r.Method,
		"url", r.URL.String(),
		"remote", r.RemoteAddr,
		"ua", r.UserAgent(),
		"xff", r.Header.Get("X-Forwarded-For"),
	)
	fail := func(code int) {
		http.Error(w, http.StatusText(code), code)
	}

	if r.Method != http.MethodPost {
		w.Header().Set("Allow", http.MethodPost)
		if r.Method == http.MethodOptions {
			return
		}
		fail(http.StatusMethodNotAllowed)
		log.Info("method not allowed")
		return
	}

	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodySize))
	if err != nil {
		// Covers both transport errors and bodies exceeding maxBodySize.
		fail(http.StatusBadRequest)
		log.Warnw("read request body", "error", err)
		return
	}

	sig := r.Header.Get("X-Hook-Signature")
	if !wr.verifyTag(sig, body) {
		fail(http.StatusForbidden)
		log.Warnw("invalid mac tag", "tag", sig)
		return
	}

	var msg Message
	if err := json.Unmarshal(body, &msg); err != nil {
		fail(http.StatusBadRequest)
		log.Warnw("unmarshal", "error", err)
		return
	}

	// Basic sanity check to avoid passing empty messages up the channel.
	if msg.Event == "" {
		fail(http.StatusBadRequest)
		log.Warn("event missing from message")
		return
	}

	log.Infow("incoming hook", "event", msg.Event)
	select {
	case wr.ch <- msg:
	case <-r.Context().Done():
		// Nobody took the message before the request ended; report failure
		// so the sender (if still listening) can retry.
		fail(http.StatusServiceUnavailable)
		log.Warnw("request cancelled before hook was consumed",
			"event", msg.Event, "error", r.Context().Err())
	}
}

// verifyTag reports whether header carries a valid signature for body.
//
// When no MAC key is configured, every request is accepted. Otherwise header
// must parse (via disassembleHookSignature) into a hex-encoded v1 tag and a
// timestamp t. The tag must then equal HMAC-SHA512(macKey, "<t>.<body>"),
// where t is written in decimal. The comparison uses hmac.Equal, which does
// not leak timing information.
//
// NOTE(review): t is covered by the MAC but never compared with the current
// time, so a captured request can be replayed for as long as the key is
// valid.
func (wr *receiver) verifyTag(header string, body []byte) bool {
	if len(wr.macKey) == 0 {
		return true // authentication not configured
	}

	v1, ts, err := disassembleHookSignature(header)
	if err != nil {
		return false
	}
	got, err := hex.DecodeString(v1)
	if err != nil {
		return false
	}

	// Signed message: decimal timestamp, a dot, then the raw body bytes.
	var signed bytes.Buffer
	signed.WriteString(strconv.FormatInt(ts, 10))
	signed.WriteByte('.')
	signed.Write(body)

	h := hmac.New(sha512.New, wr.macKey)
	h.Write(signed.Bytes())
	return hmac.Equal(got, h.Sum(nil))
}

// disassembleHookSignature splits an X-Hook-Signature header value such as
// "v1=<hex tag>,t=<timestamp>" into its parts.
//
// Format rules:
//   - The value is a comma-separated list of key=value elements, and every
//     element must contain "=".
//   - Keys other than v1 and t are ignored.
//   - A repeated key overwrites its earlier value.
//   - Both v1 (non-empty) and t (a positive base-10 int64) are required.
//
// A failure to parse t wraps the underlying strconv error, so callers can
// inspect it with errors.Is/errors.As.
func disassembleHookSignature(s string) (v1 string, t int64, err error) {
	elems := strings.Split(s, ",")
	if len(elems) < 2 {
		return "", 0, fmt.Errorf(
			"invalid argument: need at least v1 and t parts: %q", s)
	}

	for _, e := range elems {
		kv := strings.SplitN(e, "=", 2)
		if len(kv) != 2 {
			return "", 0, fmt.Errorf(
				"invalid argument: expect k=v, got %q", e)
		}
		key, val := kv[0], kv[1]
		switch key {
		case "v1":
			if val == "" {
				return "", 0, errors.New(
					"invalid argument: missing value for v1")
			}
			v1 = val
		case "t":
			// perr rather than err: don't shadow the named result.
			d, perr := strconv.ParseInt(val, 10, 64)
			if perr != nil {
				return "", 0, fmt.Errorf("invalid argument: t: %w", perr)
			}
			if d <= 0 {
				return "", 0, fmt.Errorf("invalid argument: t: %d", d)
			}
			t = d
		default:
			// skip keys we don't know how to handle
		}
	}

	if v1 == "" {
		return "", 0, errors.New("invalid argument: missing v1")
	}
	if t == 0 {
		return "", 0, errors.New("invalid argument: missing t")
	}

	return v1, t, nil
}