package interfaces

import (
	"fmt"
	"net"

	gnmitargetygot "code.fbi.h-da.de/danet/gnmi-target/examples/example01/model"
	"code.fbi.h-da.de/danet/gnmi-target/examples/example01/osclient"
	"code.fbi.h-da.de/danet/gnmi-target/examples/example01/osclient/additions"
	"github.com/openconfig/gnmi/proto/gnmi"
	"github.com/openconfig/ygot/ygot"
	log "github.com/sirupsen/logrus"
)

// InterfacesHandler is the implementation of a gnmitarget.PathHandler.
// It is responsible for the "/interfaces" subtree of the target's
// configuration: Init fills that subtree from the host's interfaces and
// Update pushes the configured values back to the host via osClient.
type InterfacesHandler struct {
	name     string              // handler identifier, returned by Name
	paths    map[string]struct{} // gNMI paths served by this handler, returned by Paths
	osClient osclient.Osclient   // reads, subscribes to and sets host interfaces
}

// NewInterfacesHandler returns an InterfacesHandler named
// "openconfig-interfaces-handler" that serves the "/interfaces" path and
// talks to the host through a new client from osclient.NewOsClient.
func NewInterfacesHandler() *InterfacesHandler {
	return &InterfacesHandler{
		name: "openconfig-interfaces-handler",
		paths: map[string]struct{}{
			"/interfaces": {},
		},
		osClient: osclient.NewOsClient(),
	}
}

// Name returns the identifier this handler was created with.
func (yh *InterfacesHandler) Name() string {
	return yh.name
}

// Paths returns the set of gNMI paths this handler is responsible for.
//
// A fresh copy is returned on every call so that callers cannot modify the
// handler's own registration by writing to the result.
func (yh *InterfacesHandler) Paths() map[string]struct{} {
	paths := make(map[string]struct{}, len(yh.paths))
	for p := range yh.paths {
		paths[p] = struct{}{}
	}
	return paths
}

// Init fills the /interfaces subtree of c with the interfaces currently
// present on the host and keeps it up to date afterwards.
//
// It reads all local interfaces through the OS client and merges each of
// them into c. It then subscribes to interface changes and starts a
// goroutine that merges every received update into the same subtree. That
// goroutine runs until the subscription channel is closed. Errors while
// applying an update are logged, not returned.
//
// An error is returned if c is not a *gnmitargetygot.Gnmitarget, if the
// local interfaces cannot be read or merged, or if the subscription fails.
func (yh *InterfacesHandler) Init(c ygot.ValidatedGoStruct) error {
	config, ok := c.(*gnmitargetygot.Gnmitarget)
	if !ok {
		return fmt.Errorf("failed type assertion for config %T", (*gnmitargetygot.Gnmitarget)(nil))
	}

	// needed for interfaces and network instances
	localInterfaces, err := yh.osClient.GetInterfaces()
	if err != nil {
		return fmt.Errorf("fetching local interfaces: %w", err)
	}

	confInterfaces := config.GetOrCreateInterfaces()
	for _, localInterface := range localInterfaces {
		if err := updateOrCreateInterface(confInterfaces, localInterface); err != nil {
			return err
		}
	}

	interfaceChannel, err := yh.osClient.SubscribeToInterfaces()
	if err != nil {
		return fmt.Errorf("subscribing to interface changes: %w", err)
	}

	// NOTE(review): this goroutine mutates confInterfaces with no locking
	// visible here. If gNMI Get/Set handlers access the same struct
	// concurrently, this is a data race. Confirm how the target serializes
	// access. The only way to stop the goroutine is to close the
	// subscription channel.
	go func() {
		// Ranging (instead of a bare receive in an endless loop) exits
		// cleanly when the channel is closed. A bare receive would keep
		// yielding nil values forever.
		for update := range interfaceChannel {
			if update == nil {
				continue
			}
			if err := updateOrCreateInterface(confInterfaces, update); err != nil {
				log.Errorf("interface subscription: applying update: %v", err)
			}
		}
	}()

	return nil
}

// Update pushes the interface configuration held in c to the host.
//
// It walks every interface under /interfaces in c. For each one it hands the
// following to the OS client's SetInterface:
//   - state: ifindex, oper/admin status and loopback mode
//   - config: name, type and MTU
//   - every IPv4/IPv6 address of its subinterfaces that has both an IP and a
//     prefix length
//
// Interfaces named "lo" and "wlan0" are skipped. If config/name is absent,
// the interface's list key is used as its name. Processing stops at the
// first SetInterface error, which is returned wrapped with the interface
// name.
//
// NOTE(review): the individual gNMI updates are not inspected. Presumably c
// already has them applied by the caller — confirm.
func (yh *InterfacesHandler) Update(c ygot.ValidatedGoStruct, updates []*gnmi.Update) error {
	log.Debugf("Update request received for %s", yh.name)

	config, ok := c.(*gnmitargetygot.Gnmitarget)
	if !ok {
		return fmt.Errorf("failed type assertion for config %T", (*gnmitargetygot.Gnmitarget)(nil))
	}

	interfaces := config.GetInterfaces()
	if interfaces == nil {
		return nil
	}

	// toIPAddress converts a configured address/prefix pair into the OS
	// client's representation. bits is the address width used for the mask:
	// 32 for IPv4, 128 for IPv6. It reports false if ip cannot be parsed.
	toIPAddress := func(ip string, prefixLen uint8, bits int) (additions.IPAddress, bool) {
		parsed := net.ParseIP(ip)
		if parsed == nil {
			return additions.IPAddress{}, false
		}
		return additions.IPAddress{
			IPNet: net.IPNet{
				IP:   parsed,
				Mask: net.CIDRMask(int(prefixLen), bits),
			},
		}, true
	}

	for key, intf := range interfaces.Interface {
		if intf == nil {
			continue
		}

		osInterface := &additions.Interface{
			Ipv4Addresses: make([]additions.IPAddress, 0),
			Ipv6Addresses: make([]additions.IPAddress, 0),
		}

		if state := intf.GetState(); state != nil {
			osInterface.Index = state.Ifindex
			osInterface.OperState = state.OperStatus
			osInterface.AdminStatus = state.AdminStatus
			osInterface.LoopbackMode = state.LoopbackMode
		}

		if intfConfig := intf.GetConfig(); intfConfig != nil {
			osInterface.Name = intfConfig.Name
			osInterface.Type = intfConfig.Type
			osInterface.MTU = intfConfig.Mtu
		}
		if osInterface.Name == nil {
			// config/name is optional in the tree. Fall back to the list key
			// instead of dereferencing a nil pointer below.
			osInterface.Name = ygot.String(key)
		}

		// NOTE(review): hard-coded skip list. Presumably this protects the
		// target's own loopback and uplink from being reconfigured — confirm.
		if name := *osInterface.Name; name == "lo" || name == "wlan0" {
			continue
		}

		if subinterfaces := intf.Subinterfaces; subinterfaces != nil {
			for _, subintf := range subinterfaces.Subinterface {
				if ipv4 := subintf.GetIpv4(); ipv4 != nil && ipv4.Addresses != nil {
					for _, addr := range ipv4.Addresses.Address {
						if addr == nil || addr.Ip == nil || addr.Config == nil || addr.Config.PrefixLength == nil {
							continue
						}
						if ipAddr, ok := toIPAddress(*addr.Ip, *addr.Config.PrefixLength, 32); ok {
							osInterface.Ipv4Addresses = append(osInterface.Ipv4Addresses, ipAddr)
						}
					}
				}

				if ipv6 := subintf.GetIpv6(); ipv6 != nil && ipv6.Addresses != nil {
					for _, addr := range ipv6.Addresses.Address {
						if addr == nil || addr.Ip == nil || addr.Config == nil || addr.Config.PrefixLength == nil {
							continue
						}
						// IPv6 masks are 128 bits wide. CIDRMask(n, 32)
						// returns nil for any IPv6 prefix longer than 32.
						if ipAddr, ok := toIPAddress(*addr.Ip, *addr.Config.PrefixLength, 128); ok {
							osInterface.Ipv6Addresses = append(osInterface.Ipv6Addresses, ipAddr)
						}
					}
				}
			}
		}

		if err := yh.osClient.SetInterface(osInterface); err != nil {
			return fmt.Errorf("setting interface %q: %w", *osInterface.Name, err)
		}
	}

	return nil
}

// updateOrCreateInterface merges the OS view of one network interface into
// the ygot /interfaces container.
//
// The entry keyed by *localInterface.Name is created if it does not exist.
// Its state (ifindex, oper/admin status, loopback mode) and config (type,
// MTU, name) are overwritten with the values from localInterface. The i-th
// IPv4 and the i-th IPv6 address are both stored under subinterface i.
// Addresses already present in the tree are not removed. After merging, the
// whole confInterfaces container is validated.
//
// It returns an error if localInterface is nil or has no name, or if
// validation fails.
func updateOrCreateInterface(confInterfaces *gnmitargetygot.OpenconfigInterfaces_Interfaces, localInterface *additions.Interface) error {
	if localInterface == nil || localInterface.Name == nil {
		return fmt.Errorf("updating interface: interface has no name")
	}

	iface := confInterfaces.GetOrCreateInterface(*localInterface.Name)
	iface.Name = localInterface.Name

	state := iface.GetOrCreateState()
	state.Ifindex = localInterface.Index
	state.OperStatus = localInterface.OperState
	state.AdminStatus = localInterface.AdminStatus
	state.LoopbackMode = localInterface.LoopbackMode

	//base ethernet interface type would be 6 (see iana-if-type.yang)
	config := iface.GetOrCreateConfig()
	config.Type = localInterface.Type
	config.Mtu = localInterface.MTU
	config.Name = localInterface.Name

	for i, addr := range localInterface.Ipv4Addresses {
		subiface := iface.GetOrCreateSubinterfaces().GetOrCreateSubinterface(uint32(i))
		subiface.GetOrCreateConfig().Index = ygot.Uint32(uint32(i))
		subiface.GetOrCreateState()

		ip := addr.IP.String()
		ipv4AddrConf := subiface.GetOrCreateIpv4().GetOrCreateAddresses().GetOrCreateAddress(ip).GetOrCreateConfig()
		ipv4AddrConf.Ip = ygot.String(ip)

		prefix, _ := addr.IPNet.Mask.Size()
		ipv4AddrConf.PrefixLength = ygot.Uint8(uint8(prefix))
	}

	for i, addr := range localInterface.Ipv6Addresses {
		subiface := iface.GetOrCreateSubinterfaces().GetOrCreateSubinterface(uint32(i))
		// Set the index here as well. Otherwise a subinterface that only
		// carries IPv6 addresses would have no config/index.
		subiface.GetOrCreateConfig().Index = ygot.Uint32(uint32(i))
		subiface.GetOrCreateState()

		ip := addr.IP.String()
		ipv6AddrConf := subiface.GetOrCreateIpv6().GetOrCreateAddresses().GetOrCreateAddress(ip).GetOrCreateConfig()
		ipv6AddrConf.Ip = ygot.String(ip)

		prefix, _ := addr.IPNet.Mask.Size()
		ipv6AddrConf.PrefixLength = ygot.Uint8(uint8(prefix))
	}

	if err := confInterfaces.Validate(); err != nil {
		return fmt.Errorf("validating interfaces after updating %q: %w", *localInterface.Name, err)
	}

	return nil
}