Skip to content
Snippets Groups Projects
udp.go 4.51 KiB
Newer Older
  • Learn to ignore specific revisions
  • Konrad Zemek's avatar
    Konrad Zemek committed
    // Copyright 2019 Path Network, Inc. All rights reserved.
    // Use of this source code is governed by a BSD-style
    // license that can be found in the LICENSE file.
    
    package main
    
    import (
    	"context"
    
    	"errors"
    	"log/slog"
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    	"net"
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    	"sync/atomic"
    	"syscall"
    	"time"
    )
    
    type udpConnection struct {
    	lastActivity   *int64
    	clientAddr     *net.UDPAddr
    	downstreamAddr *net.UDPAddr
    	upstream       *net.UDPConn
    
    	logger         *slog.Logger
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    }
    
    func udpCloseAfterInactivity(conn *udpConnection, socketClosures chan<- string) {
    	for {
    		lastActivity := atomic.LoadInt64(conn.lastActivity)
    		<-time.After(Opts.UDPCloseAfter)
    		if atomic.LoadInt64(conn.lastActivity) == lastActivity {
    			break
    		}
    	}
    	conn.upstream.Close()
    
    	if conn.clientAddr != nil {
    		socketClosures <- conn.clientAddr.String()
    	} else {
    		socketClosures <- ""
    	}
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    }
    
    func udpCopyFromUpstream(downstream net.PacketConn, conn *udpConnection) {
    	rawConn, err := conn.upstream.SyscallConn()
    	if err != nil {
    
    		conn.logger.Error("failed to retrieve raw connection from upstream socket", "error", err)
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    		return
    	}
    
    	var syscallErr error
    
    	err = rawConn.Read(func(fd uintptr) bool {
    		buf := GetBuffer()
    		defer PutBuffer(buf)
    
    		for {
    			n, _, serr := syscall.Recvfrom(int(fd), buf, syscall.MSG_DONTWAIT)
    
    			if errors.Is(serr, syscall.EWOULDBLOCK) {
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    				return false
    			}
    			if serr != nil {
    				syscallErr = serr
    				return true
    			}
    			if n == 0 {
    				return true
    			}
    
    			atomic.AddInt64(conn.lastActivity, 1)
    
    			if _, serr := downstream.WriteTo(buf[:n], conn.downstreamAddr); serr != nil {
    				syscallErr = serr
    				return true
    			}
    		}
    	})
    
    	if err == nil {
    		err = syscallErr
    	}
    	if err != nil {
    
    		conn.logger.Debug("failed to read from upstream", "error", err)
    
    func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Addr, logger *slog.Logger,
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    	connMap map[string]*udpConnection, socketClosures chan<- string) (*udpConnection, error) {
    	connKey := ""
    	if saddr != nil {
    		connKey = saddr.String()
    	}
    
    	if conn := connMap[connKey]; conn != nil {
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    		atomic.AddInt64(conn.lastActivity, 1)
    		return conn, nil
    	}
    
    	targetAddr := Opts.TargetAddr6
    
    	if netip.MustParseAddr(downstreamAddr.String()).Is4() {
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    		targetAddr = Opts.TargetAddr4
    	}
    
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    	logger = logger.With(slog.String("downstreamAddr", downstreamAddr.String()), slog.String("targetAddr", targetAddr.String()))
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    	dialer := net.Dialer{LocalAddr: saddr}
    	if saddr != nil {
    
    		logger = logger.With(slog.String("clientAddr", saddr.String()))
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    		dialer.Control = DialUpstreamControl(saddr.(*net.UDPAddr).Port)
    	}
    
    	if Opts.Verbose > 1 {
    		logger.Debug("new connection")
    	}
    
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    	conn, err := dialer.Dial("udp", targetAddr.String())
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    	if err != nil {
    
    		logger.Debug("failed to connect to upstream", "error", err)
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    		return nil, err
    	}
    
    	udpConn := &udpConnection{upstream: conn.(*net.UDPConn),
    		logger:         logger,
    		lastActivity:   new(int64),
    		downstreamAddr: downstreamAddr.(*net.UDPAddr)}
    
    	if saddr != nil {
    		udpConn.clientAddr = saddr.(*net.UDPAddr)
    	}
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    
    	go udpCopyFromUpstream(downstream, udpConn)
    	go udpCloseAfterInactivity(udpConn, socketClosures)
    
    	connMap[connKey] = udpConn
    	return udpConn, nil
    }
    
    
    func UDPListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan<- error) {
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    	ctx := context.Background()
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    	ln, err := listenConfig.ListenPacket(ctx, "udp", Opts.ListenAddr.String())
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    	if err != nil {
    
    		logger.Error("failed to bind listener", "error", err)
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    		errors <- err
    		return
    	}
    
    	logger.Info("listening")
    
    	socketClosures := make(chan string, 1024)
    	connectionMap := make(map[string]*udpConnection)
    
    	buffer := GetBuffer()
    	defer PutBuffer(buffer)
    
    	for {
    		n, remoteAddr, err := ln.ReadFrom(buffer)
    		if err != nil {
    
    			logger.Error("failed to read from socket", "error", err)
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    			continue
    		}
    
    		if !CheckOriginAllowed(remoteAddr.(*net.UDPAddr).IP) {
    
    			logger.Debug("packet origin not in allowed subnets", slog.String("remoteAddr", remoteAddr.String()))
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    			continue
    		}
    
    		saddr, _, restBytes, err := PROXYReadRemoteAddr(buffer[:n], UDP)
    		if err != nil {
    
    			logger.Debug("failed to parse PROXY header", "error", err, slog.String("remoteAddr", remoteAddr.String()))
    
    Konrad Zemek's avatar
    Konrad Zemek committed
    			continue
    		}
    
    		for {
    			doneClosing := false
    			select {
    			case mapKey := <-socketClosures:
    				delete(connectionMap, mapKey)
    			default:
    				doneClosing = true
    			}
    			if doneClosing {
    				break
    			}
    		}
    
    		conn, err := udpGetSocketFromMap(ln, remoteAddr, saddr, logger, connectionMap, socketClosures)
    		if err != nil {
    			continue
    		}
    
    		_, err = conn.upstream.Write(restBytes)
    		if err != nil {
    
    			conn.logger.Error("failed to write to upstream socket", "error", err)