// Copyright 2019 Path Network, Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
	"context"
	"errors"
	"log/slog"
	"net"
	"net/netip"
	"sync/atomic"
	"syscall"
	"time"
)

// udpConnection holds the state of one proxied UDP flow between a downstream
// peer and the upstream target.
type udpConnection struct {
	// lastActivity is a counter incremented atomically as datagrams are
	// relayed; udpCloseAfterInactivity closes the flow once it stops changing.
	lastActivity *int64

	// clientAddr is the client address taken from the PROXY header (nil when
	// the header supplied none); downstreamAddr is the peer that sent the
	// datagram to the listener and the destination for upstream replies.
	clientAddr, downstreamAddr *net.UDPAddr

	// upstream is the socket connected to the target address.
	upstream *net.UDPConn
	// logger carries the per-flow address attributes.
	logger *slog.Logger
}

// udpCloseAfterInactivity waits until a whole Opts.UDPCloseAfter period
// elapses without conn.lastActivity changing, then closes conn.upstream and
// sends the flow's map key on socketClosures so the listener can forget it.
// The key is the client address string, or "" when conn has no client address.
func udpCloseAfterInactivity(conn *udpConnection, socketClosures chan<- string) {
	idle := false
	for !idle {
		snapshot := atomic.LoadInt64(conn.lastActivity)
		time.Sleep(Opts.UDPCloseAfter)
		idle = atomic.LoadInt64(conn.lastActivity) == snapshot
	}

	conn.upstream.Close()

	key := ""
	if conn.clientAddr != nil {
		key = conn.clientAddr.String()
	}
	socketClosures <- key
}

// udpCopyFromUpstream relays every datagram received on conn.upstream to
// conn.downstreamAddr through downstream, incrementing conn.lastActivity for
// each one. It returns when receiving or forwarding fails, including when
// conn.upstream is closed.
func udpCopyFromUpstream(downstream net.PacketConn, conn *udpConnection) {
	rawConn, err := conn.upstream.SyscallConn()
	if err != nil {
		conn.logger.Error("failed to retrieve raw connection from upstream socket", "error", err)
		return
	}

	var syscallErr error

	// The callback drains the socket with non-blocking recvfrom calls. It
	// returns false on EWOULDBLOCK so rawConn.Read waits for readability and
	// calls it again; returning true ends the copy.
	err = rawConn.Read(func(fd uintptr) bool {
		buf := GetBuffer()
		defer PutBuffer(buf)

		for {
			n, _, serr := syscall.Recvfrom(int(fd), buf, syscall.MSG_DONTWAIT)
			if errors.Is(serr, syscall.EWOULDBLOCK) {
				return false
			}
			if errors.Is(serr, syscall.ECONNREFUSED) {
				// A connected UDP socket reports an earlier ICMP "port
				// unreachable" here. Reading the error clears it, so keep
				// serving the flow instead of abandoning the return path.
				continue
			}
			if serr != nil {
				syscallErr = serr
				return true
			}

			// n == 0 is a valid empty UDP datagram, not end-of-stream, so it
			// is forwarded like any other payload.
			atomic.AddInt64(conn.lastActivity, 1)

			if _, serr := downstream.WriteTo(buf[:n], conn.downstreamAddr); serr != nil {
				syscallErr = serr
				return true
			}
		}
	})

	if err == nil {
		err = syscallErr
	}
	// net.ErrClosed means conn.upstream was closed deliberately (for example
	// by udpCloseAfterInactivity), which is the normal way for this copy to end.
	if err != nil && !errors.Is(err, net.ErrClosed) {
		conn.logger.Debug("failed to read from upstream", "error", err)
	}
}

// udpGetSocketFromMap returns the upstream connection for the flow keyed by
// saddr (or "" when saddr is nil), creating it on first use.
//
// An existing entry in connMap has its activity counter bumped and is returned
// as-is. Otherwise a UDP socket is dialed to the configured target, locally
// bound to saddr when one is known. Its relay (udpCopyFromUpstream) and
// inactivity reaper (udpCloseAfterInactivity) goroutines are started, and the
// socket is stored in connMap. connMap is a plain map with no locking here,
// so the caller must confine it to a single goroutine.
//
// An error is returned, and connMap left untouched, if the address used to
// choose the target family cannot be parsed or if dialing the target fails.
func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Addr, logger *slog.Logger,
	connMap map[string]*udpConnection, socketClosures chan<- string) (*udpConnection, error) {
	connKey := ""
	if saddr != nil {
		connKey = saddr.String()
	}
	if conn := connMap[connKey]; conn != nil {
		// Count this datagram as activity so the reaper keeps the flow open.
		atomic.AddInt64(conn.lastActivity, 1)
		return conn, nil
	}

	// The upstream socket is bound to saddr when it is known (see
	// dialer.LocalAddr below), so the target must share saddr's address
	// family; without a client address, follow the downstream peer's family.
	familyAddr := downstreamAddr
	if saddr != nil {
		familyAddr = saddr
	}
	// net.Addr.String() renders "host:port", so it must be parsed as an
	// AddrPort; parsing it as a bare address always fails.
	familyAddrPort, err := netip.ParseAddrPort(familyAddr.String())
	if err != nil {
		logger.Debug("failed to parse address", "error", err, slog.String("addr", familyAddr.String()))
		return nil, err
	}
	targetAddr := Opts.TargetAddr6
	if familyAddrPort.Addr().Unmap().Is4() {
		targetAddr = Opts.TargetAddr4
	}

	logger = logger.With(slog.String("downstreamAddr", downstreamAddr.String()), slog.String("targetAddr", targetAddr.String()))
	dialer := net.Dialer{LocalAddr: saddr}
	if saddr != nil {
		logger = logger.With(slog.String("clientAddr", saddr.String()))
		dialer.Control = DialUpstreamControl(saddr.(*net.UDPAddr).Port)
	}

	if Opts.Verbose > 1 {
		logger.Debug("new connection")
	}

	conn, err := dialer.Dial("udp", targetAddr.String())
	if err != nil {
		logger.Debug("failed to connect to upstream", "error", err)
		return nil, err
	}

	udpConn := &udpConnection{upstream: conn.(*net.UDPConn),
		logger:         logger,
		lastActivity:   new(int64),
		downstreamAddr: downstreamAddr.(*net.UDPAddr)}
	if saddr != nil {
		udpConn.clientAddr = saddr.(*net.UDPAddr)
	}

	go udpCopyFromUpstream(downstream, udpConn)
	go udpCloseAfterInactivity(udpConn, socketClosures)

	connMap[connKey] = udpConn
	return udpConn, nil
}

// UDPListen binds a UDP socket on Opts.ListenAddr using listenConfig and
// relays each datagram that passes the origin check and carries a valid PROXY
// header to its flow's upstream connection, creating connections on demand.
//
// A bind failure, or the listening socket being closed, is logged and sent on
// errCh, after which UDPListen returns; other read errors are logged and the
// loop continues.
func UDPListen(listenConfig *net.ListenConfig, logger *slog.Logger, errCh chan<- error) {
	ctx := context.Background()
	ln, err := listenConfig.ListenPacket(ctx, "udp", Opts.ListenAddr.String())
	if err != nil {
		logger.Error("failed to bind listener", "error", err)
		errCh <- err
		return
	}

	logger.Info("listening")

	// Map keys of flows closed by udpCloseAfterInactivity; drained before
	// every lookup so closed connections are removed from the map.
	socketClosures := make(chan string, 1024)
	// Only this goroutine touches connectionMap, so it needs no lock.
	connectionMap := make(map[string]*udpConnection)

	buffer := GetBuffer()
	defer PutBuffer(buffer)

	for {
		n, remoteAddr, err := ln.ReadFrom(buffer)
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				// Every further read on a closed socket fails immediately;
				// retrying would spin forever.
				logger.Error("listener closed", "error", err)
				errCh <- err
				return
			}
			logger.Error("failed to read from socket", "error", err)
			continue
		}

		if !CheckOriginAllowed(remoteAddr.(*net.UDPAddr).IP) {
			logger.Debug("packet origin not in allowed subnets", slog.String("remoteAddr", remoteAddr.String()))
			continue
		}

		saddr, _, restBytes, err := PROXYReadRemoteAddr(buffer[:n], UDP)
		if err != nil {
			logger.Debug("failed to parse PROXY header", "error", err, slog.String("remoteAddr", remoteAddr.String()))
			continue
		}

	drainClosures:
		for {
			select {
			case mapKey := <-socketClosures:
				delete(connectionMap, mapKey)
			default:
				break drainClosures
			}
		}

		conn, err := udpGetSocketFromMap(ln, remoteAddr, saddr, logger, connectionMap, socketClosures)
		if err != nil {
			continue
		}

		// NOTE(review): restBytes presumably aliases buffer; the write below
		// is synchronous, so buffer is not reused until it returns.
		_, err = conn.upstream.Write(restBytes)
		if err != nil {
			conn.logger.Error("failed to write to upstream socket", "error", err)
		}
	}
}