Skip to content
Snippets Groups Projects
proxyprotocol.go 4.32 KiB
Newer Older
  • Learn to ignore specific revisions
  • Konrad Zemek's avatar
    Konrad Zemek committed
    // Copyright 2019 Path Network, Inc. All rights reserved.
    // Use of this source code is governed by a BSD-style
    // license that can be found in the LICENSE file.
    
    package main
    
    import (
    	"bytes"
    	"encoding/binary"
    	"fmt"
    	"net"
    	"strings"
    )
    
    // readRemoteAddrPROXYv2 parses a binary PROXY protocol v2 header located at the
    // start of ctrlBuf.
    //
    // The 12-byte signature is not inspected here; only the version/command byte,
    // the family/transport byte, the 2-byte big-endian length and the address
    // block are read.
    //
    // For the PROXY command with protocol TCP the header must declare TCP over
    // IPv4 or IPv6 (0x11, 0x21); with protocol UDP it must declare UDP over IPv4
    // or IPv6 (0x12, 0x22). For the LOCAL command the family/transport byte is
    // ignored and both returned addresses are nil.
    //
    // It returns the source address, the destination address (*net.TCPAddr or
    // *net.UDPAddr, depending on the transport declared in the header) and the
    // bytes that follow the header. The returned addresses do not share memory
    // with ctrlBuf; the returned remainder does.
    //
    // An error is returned when the header is truncated, declares an unsupported
    // version, command, family or transport, or declares a length too small to
    // hold the addresses and ports of its family.
    func readRemoteAddrPROXYv2(ctrlBuf []byte, protocol Protocol) (net.Addr, net.Addr, []byte, error) {
    	// Fixed part: signature (12), version/command (1), family/transport (1),
    	// length of the following data (2).
    	const fixedLen = 16
    	if len(ctrlBuf) < fixedLen {
    		return nil, nil, nil, fmt.Errorf("incomplete PROXY header")
    	}

    	version, command := ctrlBuf[12]>>4, ctrlBuf[12]&0xF
    	if version != 2 {
    		return nil, nil, nil, fmt.Errorf("unknown protocol version %d", version)
    	}
    	if command > 1 {
    		return nil, nil, nil, fmt.Errorf("unknown command %d", command)
    	}

    	famProto := ctrlBuf[13]
    	family, transport := famProto>>4, famProto&0xF
    	if command == 1 && ((protocol == TCP && famProto != 0x11 && famProto != 0x21) ||
    		(protocol == UDP && famProto != 0x12 && famProto != 0x22)) {
    		return nil, nil, nil, fmt.Errorf("invalid family/protocol %d/%d", family, transport)
    	}

    	// Widen to int before any arithmetic: 16+dataLen does not fit in a uint16
    	// for lengths close to 65535 and would wrap around.
    	dataLen := int(binary.BigEndian.Uint16(ctrlBuf[14:16]))
    	if len(ctrlBuf) < fixedLen+dataLen {
    		return nil, nil, nil, fmt.Errorf("incomplete PROXY header")
    	}
    	rest := ctrlBuf[fixedLen+dataLen:]

    	if command == 0 { // LOCAL
    		return nil, nil, rest, nil
    	}

    	// The checks below also cover callers passing a protocol other than TCP or
    	// UDP, for which the family/transport byte has not been validated above.
    	var ipLen int
    	switch family {
    	case 0x1: // IPv4
    		ipLen = net.IPv4len
    	case 0x2: // IPv6
    		ipLen = net.IPv6len
    	default:
    		return nil, nil, nil, fmt.Errorf("invalid family/protocol %d/%d", family, transport)
    	}
    	if transport != 0x1 && transport != 0x2 {
    		return nil, nil, nil, fmt.Errorf("invalid family/protocol %d/%d", family, transport)
    	}

    	// Address block layout: source IP, destination IP, source port,
    	// destination port. The declared length must cover all of it; otherwise
    	// the fields would be read from the payload or from beyond the buffer.
    	addrLen := 2*ipLen + 4
    	if dataLen < addrLen {
    		return nil, nil, nil, fmt.Errorf("address data length %d too short for family/protocol %d/%d",
    			dataLen, family, transport)
    	}
    	addrs := ctrlBuf[fixedLen : fixedLen+addrLen]

    	var srcIP, dstIP net.IP
    	if family == 0x1 {
    		srcIP = net.IPv4(addrs[0], addrs[1], addrs[2], addrs[3])
    		dstIP = net.IPv4(addrs[4], addrs[5], addrs[6], addrs[7])
    	} else {
    		// Copy so the addresses stay valid if the caller later overwrites
    		// ctrlBuf.
    		srcIP = append(net.IP(nil), addrs[:ipLen]...)
    		dstIP = append(net.IP(nil), addrs[ipLen:2*ipLen]...)
    	}
    	sport := int(binary.BigEndian.Uint16(addrs[2*ipLen:]))
    	dport := int(binary.BigEndian.Uint16(addrs[2*ipLen+2:]))

    	if transport == 0x1 { // TCP
    		return &net.TCPAddr{IP: srcIP, Port: sport},
    			&net.TCPAddr{IP: dstIP, Port: dport},
    			rest, nil
    	}

    	return &net.UDPAddr{IP: srcIP, Port: sport},
    		&net.UDPAddr{IP: dstIP, Port: dport},
    		rest, nil
    }
    
    // readRemoteAddrPROXYv1 parses a text PROXY protocol v1 header located at the
    // start of ctrlBuf.
    //
    // The header is the line terminated by the first "\r\n" in ctrlBuf and has one
    // of these forms:
    //
    //	PROXY TCP4 <source IP> <destination IP> <source port> <destination port>
    //	PROXY TCP6 <source IP> <destination IP> <source port> <destination port>
    //	PROXY UNKNOWN [ignored fields]
    //
    // It returns the source and destination addresses as *net.TCPAddr and the
    // bytes that follow the header line. For UNKNOWN both addresses are nil.
    //
    // An error is returned when the terminator is missing, the line does not
    // start with "PROXY ", the protocol is not one of the above, the field count
    // is wrong, an address does not parse or does not belong to the declared
    // family, or a port is not a decimal number in [0, 65535].
    func readRemoteAddrPROXYv1(ctrlBuf []byte) (net.Addr, net.Addr, []byte, error) {
    	end := bytes.Index(ctrlBuf, []byte("\r\n"))
    	if end < 0 {
    		return nil, nil, nil, fmt.Errorf("did not find \\r\\n in first data segment")
    	}
    	// Only the header line is converted to a string; the payload after it is
    	// returned untouched and never scanned.
    	header, rest := string(ctrlBuf[:end]), ctrlBuf[end+2:]

    	if !strings.HasPrefix(header, "PROXY ") {
    		return nil, nil, nil, fmt.Errorf("failed to decode elements")
    	}
    	fields := strings.Fields(header)
    	if len(fields) < 2 {
    		return nil, nil, nil, fmt.Errorf("failed to decode elements")
    	}

    	headerProtocol := fields[1]
    	switch headerProtocol {
    	case "UNKNOWN":
    		// Whatever follows UNKNOWN on the line is ignored.
    		return nil, nil, rest, nil
    	case "TCP4", "TCP6":
    	default:
    		return nil, nil, nil, fmt.Errorf("unknown protocol %s", headerProtocol)
    	}

    	// PROXY, protocol, two addresses, two ports: anything else is malformed.
    	if len(fields) != 6 {
    		return nil, nil, nil, fmt.Errorf("failed to decode elements")
    	}
    	src, dst := fields[2], fields[3]

    	srcIP := net.ParseIP(src)
    	if srcIP == nil {
    		return nil, nil, nil, fmt.Errorf("failed to parse source IP address %s", src)
    	}
    	dstIP := net.ParseIP(dst)
    	if dstIP == nil {
    		return nil, nil, nil, fmt.Errorf("failed to parse destination IP address %s", dst)
    	}

    	// An address that parsed successfully is IPv6 text exactly when it
    	// contains a colon; it has to agree with the declared protocol.
    	wantIPv6 := headerProtocol == "TCP6"
    	for _, addr := range []string{src, dst} {
    		if strings.Contains(addr, ":") != wantIPv6 {
    			return nil, nil, nil, fmt.Errorf("address %s does not match protocol %s", addr, headerProtocol)
    		}
    	}

    	sport, err := parsePROXYv1Port(fields[4])
    	if err != nil {
    		return nil, nil, nil, fmt.Errorf("failed to decode source port: %s", err.Error())
    	}
    	dport, err := parsePROXYv1Port(fields[5])
    	if err != nil {
    		return nil, nil, nil, fmt.Errorf("failed to decode destination port: %s", err.Error())
    	}

    	return &net.TCPAddr{IP: srcIP, Port: sport},
    		&net.TCPAddr{IP: dstIP, Port: dport},
    		rest, nil
    }

    // parsePROXYv1Port converts s, which must consist solely of ASCII digits, to a
    // port number in [0, 65535]. Signs, empty strings and any other characters
    // are rejected.
    func parsePROXYv1Port(s string) (int, error) {
    	if s == "" {
    		return 0, fmt.Errorf("empty port")
    	}
    	port := 0
    	for i := 0; i < len(s); i++ {
    		c := s[i]
    		if c < '0' || c > '9' {
    			return 0, fmt.Errorf("invalid port %q", s)
    		}
    		port = port*10 + int(c-'0')
    		// Checked on every digit so an arbitrarily long input cannot overflow.
    		if port > 65535 {
    			return 0, fmt.Errorf("port %q out of range", s)
    		}
    	}
    	return port, nil
    }
    
    // PROXYReadRemoteAddr detects and parses a PROXY protocol header at the start
    // of buf.
    //
    // A buffer of at least 16 bytes that begins with the 12-byte v2 signature is
    // handed to readRemoteAddrPROXYv2. Otherwise, when protocol is TCP and buf
    // holds at least 8 bytes beginning with "PROXY", it is handed to
    // readRemoteAddrPROXYv1. Any other input yields a "PROXY header missing"
    // error.
    //
    // On success it returns the parser's source address, destination address and
    // the bytes following the header. A parser failure is returned with a prefix
    // naming the header version, and all other results are nil.
    func PROXYReadRemoteAddr(buf []byte, protocol Protocol) (net.Addr, net.Addr, []byte, error) {
    	v2Signature := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}

    	switch {
    	case len(buf) >= 16 && bytes.HasPrefix(buf, v2Signature):
    		src, dst, payload, parseErr := readRemoteAddrPROXYv2(buf, protocol)
    		if parseErr == nil {
    			return src, dst, payload, nil
    		}
    		return nil, nil, nil, fmt.Errorf("failed to parse PROXY v2 header: %s", parseErr.Error())

    	// The text (v1) format is only considered for TCP.
    	case protocol == TCP && len(buf) >= 8 && bytes.HasPrefix(buf, []byte("PROXY")):
    		src, dst, payload, parseErr := readRemoteAddrPROXYv1(buf)
    		if parseErr == nil {
    			return src, dst, payload, nil
    		}
    		return nil, nil, nil, fmt.Errorf("failed to parse PROXY v1 header: %s", parseErr.Error())
    	}

    	return nil, nil, nil, fmt.Errorf("PROXY header missing")
    }