From 1d587405cece9fc4b164030fcae613c032fd6fd1 Mon Sep 17 00:00:00 2001
From: Konrad Zemek <konrad.zemek@gmail.com>
Date: Tue, 28 May 2019 01:06:49 +0200
Subject: [PATCH] Implement a mmproxy replacement in Go.

---
 LICENSE |  27 ++++++
 main.go | 267 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 294 insertions(+)
 create mode 100644 LICENSE
 create mode 100644 main.go

diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..dc9f14e
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,27 @@
+Copyright (c) 2019 Path Network, Inc. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+   * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+   * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+   * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..44cb05e
--- /dev/null
+++ b/main.go
@@ -0,0 +1,267 @@
+// Copyright 2019 Path Network, Inc. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+	"bytes"
+	"encoding/binary"
+	"flag"
+	"fmt"
+	"io"
+	"log"
+	"net"
+	"strings"
+	"syscall"
+)
+
+var listenAddr string
+var targetAddr4 string
+var targetAddr6 string
+var mark int
+
+func init() {
+	flag.StringVar(&listenAddr, "l", "0.0.0.0:8443", "Adress the proxy listens on")
+	flag.StringVar(&targetAddr4, "4", "0.0.0.0:443", "Address to which IPv4 TCP traffic will be forwarded to")
+	flag.StringVar(&targetAddr6, "6", "[::]:443", "Address to which IPv6 TCP traffic will be forwarded to")
+	flag.IntVar(&mark, "mark", 123, "The mark that will be set on outbound packets")
+}
+
+func readRemoteAddrPROXYv2(conn net.Conn, ctrlBuf []byte) (net.Addr, net.Addr, []byte, error) {
+	if (ctrlBuf[12] >> 4) != 2 {
+		return nil, nil, nil, fmt.Errorf("unknown protocol version %d", ctrlBuf[12]>>4)
+	}
+
+	if ctrlBuf[12]&0xFF > 1 {
+		return nil, nil, nil, fmt.Errorf("unknown command %d", ctrlBuf[12]&0xFF)
+	}
+
+	if ctrlBuf[12]&0xFF == 1 && ctrlBuf[13] != 0x11 && ctrlBuf[13] != 0x21 {
+		return nil, nil, nil, fmt.Errorf("invalid family/protocol %d/%d", ctrlBuf[13]>>4, ctrlBuf[13]&0xFF)
+	}
+
+	var dataLen uint16
+	reader := bytes.NewReader(ctrlBuf[14:16])
+	if err := binary.Read(reader, binary.BigEndian, &dataLen); err != nil {
+		return nil, nil, nil, fmt.Errorf("failed to decode address data length: %s", err.Error())
+	}
+
+	if ctrlBuf[12]&0xFF == 1 { // LOCAL
+		return conn.RemoteAddr(), conn.LocalAddr(), ctrlBuf[16+dataLen:], nil
+	}
+
+	var sport, dport uint16
+	if ctrlBuf[13] == 0x11 { // IPv4
+		reader = bytes.NewReader(ctrlBuf[24:])
+	} else {
+		reader = bytes.NewReader(ctrlBuf[48:])
+	}
+	if err := binary.Read(reader, binary.BigEndian, &sport); err != nil {
+		return nil, nil, nil, fmt.Errorf("failed to decode source TCP port: %s", err.Error())
+	}
+	if err := binary.Read(reader, binary.BigEndian, &dport); err != nil {
+		return nil, nil, nil, fmt.Errorf("failed to decode destination TCP port: %s", err.Error())
+	}
+
+	if ctrlBuf[13] == 0x11 { // TCP over IPv4
+		srcIP := net.IPv4(ctrlBuf[16], ctrlBuf[17], ctrlBuf[18], ctrlBuf[19])
+		dstIP := net.IPv4(ctrlBuf[20], ctrlBuf[21], ctrlBuf[22], ctrlBuf[23])
+		return &net.TCPAddr{IP: srcIP, Port: int(sport)}, &net.TCPAddr{IP: dstIP, Port: int(dport)}, ctrlBuf[16+dataLen:], nil
+	}
+
+	return &net.TCPAddr{IP: ctrlBuf[16:32], Port: int(sport)}, &net.TCPAddr{IP: ctrlBuf[32:48], Port: int(dport)}, ctrlBuf[16+dataLen:], nil
+}
+
+func readRemoteAddrPROXYv1(conn net.Conn, ctrlBuf []byte) (net.Addr, net.Addr, []byte, error) {
+	str := string(ctrlBuf)
+	if idx := strings.Index(str, "\r\n"); idx >= 0 {
+		var protocol, src, dst string
+		var sport, dport int
+		n, err := fmt.Sscanf(str, "PROXY %s", &protocol)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+		if n != 1 {
+			return nil, nil, nil, fmt.Errorf("failed to decode elements")
+		}
+		if protocol == "UNKNOWN" {
+			return conn.RemoteAddr(), conn.LocalAddr(), ctrlBuf[idx+2:], nil
+		}
+		if protocol != "TCP4" && protocol != "TCP6" {
+			return nil, nil, nil, fmt.Errorf("unknown protocol %s", protocol)
+		}
+
+		n, err = fmt.Sscanf(str, "PROXY %s %s %s %d %d", &protocol, &src, &dst, &sport, &dport)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+		if n != 5 {
+			return nil, nil, nil, fmt.Errorf("failed to decode elements")
+		}
+		srcIP := net.ParseIP(src)
+		if srcIP == nil {
+			return nil, nil, nil, fmt.Errorf("failed to parse source IP address %s", src)
+		}
+		dstIP := net.ParseIP(dst)
+		if dstIP == nil {
+			return nil, nil, nil, fmt.Errorf("failed to parse destination IP address %s", dst)
+		}
+		return &net.TCPAddr{IP: srcIP, Port: sport}, &net.TCPAddr{IP: dstIP, Port: dport}, ctrlBuf[idx+2:], nil
+	}
+
+	return nil, nil, nil, fmt.Errorf("did not find \\r\\n in first data segment")
+}
+
+func readRemoteAddr(conn net.Conn) (net.Addr, net.Addr, []byte, error) {
+	buf := make([]byte, 108)
+	n, err := conn.Read(buf)
+	if err != nil {
+		return nil, nil, nil, fmt.Errorf("failed to read header: %s", err.Error())
+	}
+
+	if n >= 16 && bytes.Equal(buf[:13], []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}) {
+		saddr, daddr, rest, err := readRemoteAddrPROXYv2(conn, buf[:n])
+		if err != nil {
+			return nil, nil, nil, fmt.Errorf("failed to parse PROXY v2 header: %s", err.Error())
+		}
+		return saddr, daddr, rest, err
+	}
+
+	if n >= 8 && bytes.Equal(buf[:5], []byte("PROXY")) {
+		saddr, daddr, rest, err := readRemoteAddrPROXYv1(conn, buf[:n])
+		if err != nil {
+			return nil, nil, nil, fmt.Errorf("failed to parse PROXY v1 header: %s", err.Error())
+		}
+		return saddr, daddr, rest, err
+	}
+
+	return nil, nil, nil, fmt.Errorf("PROXY header missing")
+}
+
+func dialUpstreamControl(sport int) func(string, string, syscall.RawConn) error {
+	return func(network, address string, c syscall.RawConn) error {
+		var syscallErr error
+		err := c.Control(func(fd uintptr) {
+			syscallErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, syscall.TCP_SYNCNT, 2)
+			if syscallErr != nil {
+				syscallErr = fmt.Errorf("setsockopt(IPPROTO_TCP, TCP_SYNCTNT, 2): %s", syscallErr.Error())
+				return
+			}
+
+			syscallErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TRANSPARENT, 1)
+			if syscallErr != nil {
+				syscallErr = fmt.Errorf("setsockopt(IPPROTO_IP, IP_TRANSPARENT, 1): %s", syscallErr.Error())
+				return
+			}
+
+			syscallErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+			if syscallErr != nil {
+				syscallErr = fmt.Errorf("setsockopt(SOL_SOCKET, SO_REUSEADDR, 1): %s", syscallErr.Error())
+				return
+			}
+
+			if sport == 0 {
+				ipBindAddressNoPort := 24
+				syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, ipBindAddressNoPort, 1)
+			}
+
+			syscallErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark)
+			if syscallErr != nil {
+				syscallErr = fmt.Errorf("setsockopt(SOL_SOCK, SO_MARK, %d): %s", mark, syscallErr.Error())
+				return
+			}
+
+			if network == "tcp6" {
+				syscallErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IPV6_V6ONLY, 0)
+				if syscallErr != nil {
+					syscallErr = fmt.Errorf("setsockopt(IPPROTO_IP, IPV6_ONLY, 0): %s", syscallErr.Error())
+					return
+				}
+			}
+		})
+
+		if err != nil {
+			return err
+		}
+		return syscallErr
+	}
+}
+
+func copyData(dst net.Conn, src net.Conn, ch chan<- error) {
+	_, err := io.Copy(dst, src)
+	ch <- err
+}
+
+func handleConnection(conn net.Conn) {
+	defer conn.Close()
+
+	saddr, _, restBytes, err := readRemoteAddr(conn)
+	if err != nil {
+		log.Printf("Failed to parse PROXY data from %s: %s", conn.RemoteAddr().String(), err.Error())
+		return
+	}
+
+	targetAddr := targetAddr6
+	if strings.ContainsRune(saddr.String(), '.') { // poor man's ipv6 check - golang makes it unnecessarily hard
+		targetAddr = targetAddr4
+	}
+
+	dialer := net.Dialer{LocalAddr: saddr, Control: dialUpstreamControl(saddr.(*net.TCPAddr).Port)}
+	upstreamConn, err := dialer.Dial("tcp", targetAddr)
+	if err != nil {
+		log.Printf("Failed to establish upstream connection %s -> %s (PROXY %s -> %s): %s",
+			conn.RemoteAddr().String(), conn.LocalAddr().String(), saddr.String(), targetAddr, err.Error())
+		return
+	}
+
+	if err := conn.(*net.TCPConn).SetNoDelay(true); err != nil {
+		log.Printf("Failed to set nodelay on upstream connection %s -> %s (PROXY %s -> %s): %s",
+			conn.RemoteAddr().String(), conn.LocalAddr().String(), saddr.String(), targetAddr, err.Error())
+	}
+
+	if err := upstreamConn.(*net.TCPConn).SetNoDelay(true); err != nil {
+		log.Printf("Failed to set nodelay on upstream connection %s -> %s (PROXY %s -> %s): %s",
+			conn.RemoteAddr().String(), conn.LocalAddr().String(), saddr.String(), targetAddr, err.Error())
+	}
+
+	defer upstreamConn.Close()
+
+	for len(restBytes) > 0 {
+		n, err := conn.Write(restBytes)
+		if err != nil {
+			log.Printf("Failed to write data to upstream connection %s -> %s (PROXY %s -> %s): %s",
+				conn.RemoteAddr().String(), conn.LocalAddr().String(), saddr.String(), targetAddr, err.Error())
+			return
+		}
+		restBytes = restBytes[n:]
+	}
+
+	outErr := make(chan error, 2)
+	go copyData(upstreamConn, conn, outErr)
+	go copyData(conn, upstreamConn, outErr)
+
+	err = <-outErr
+	if err != nil {
+		log.Printf("Connection %s -> %s (PROXY %s -> %s): %s",
+			conn.RemoteAddr().String(), conn.LocalAddr().String(), saddr.String(), targetAddr, err.Error())
+	}
+}
+
+func main() {
+	flag.Parse()
+
+	ln, err := net.Listen("tcp", listenAddr)
+	if err != nil {
+		log.Fatalf("Failed to bind to %s: %s\n", listenAddr, err.Error())
+	}
+
+	for {
+		conn, err := ln.Accept()
+		if err != nil {
+			log.Fatalf("Failed to accept new connection: %s\n", err.Error())
+		}
+
+		go handleConnection(conn)
+	}
+}
-- 
GitLab