diff --git a/main.go b/main.go index 44cb05eddd5b17c1645d21b0c9faa2d30c5dc121..a5fdc720aff0b89bcba1a3f0d70a37e80a5be2b0 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ package main import ( + "bufio" "bytes" "encoding/binary" "flag" @@ -12,6 +13,7 @@ import ( "io" "log" "net" + "os" "strings" "syscall" ) @@ -19,13 +21,17 @@ import ( var listenAddr string var targetAddr4 string var targetAddr6 string +var allowedSubnetsPath string var mark int +var allowedSubnets []*net.IPNet + func init() { flag.StringVar(&listenAddr, "l", "0.0.0.0:8443", "Adress the proxy listens on") flag.StringVar(&targetAddr4, "4", "0.0.0.0:443", "Address to which IPv4 TCP traffic will be forwarded to") flag.StringVar(&targetAddr6, "6", "[::]:443", "Address to which IPv6 TCP traffic will be forwarded to") flag.IntVar(&mark, "mark", 123, "The mark that will be set on outbound packets") + flag.StringVar(&allowedSubnetsPath, "allowed-subnets", "", "Path to a file that contains subnets of the proxy servers") } func readRemoteAddrPROXYv2(conn net.Conn, ctrlBuf []byte) (net.Addr, net.Addr, []byte, error) { @@ -193,9 +199,28 @@ func copyData(dst net.Conn, src net.Conn, ch chan<- error) { ch <- err } +func checkOriginAllowed(conn net.Conn) bool { + if len(allowedSubnets) == 0 { + return true + } + + addr := conn.RemoteAddr().(*net.TCPAddr) + for _, ipNet := range allowedSubnets { + if ipNet.Contains(addr.IP) { + return true + } + } + return false +} + func handleConnection(conn net.Conn) { defer conn.Close() + if !checkOriginAllowed(conn) { + log.Printf("Disallowed connection from %s", conn.RemoteAddr().String()) + return + } + saddr, _, restBytes, err := readRemoteAddr(conn) if err != nil { log.Printf("Failed to parse PROXY data from %s: %s", conn.RemoteAddr().String(), err.Error()) @@ -248,9 +273,35 @@ func handleConnection(conn net.Conn) { } } +func loadAllowedSubnets() error { + file, err := os.Open(allowedSubnetsPath) + if err != nil { + return err + } + + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + _, ipNet, err := net.ParseCIDR(scanner.Text()) + if err != nil { + return err + } + allowedSubnets = append(allowedSubnets, ipNet) + } + + return nil +} + func main() { flag.Parse() + if allowedSubnetsPath != "" { + if err := loadAllowedSubnets(); err != nil { + log.Fatalf("Failed to load allowed subnets file: %s", err.Error()) + } + } + ln, err := net.Listen("tcp", listenAddr) if err != nil { log.Fatalf("Failed to bind to %s: %s\n", listenAddr, err.Error()) diff --git a/path-prefixes.txt b/path-prefixes.txt new file mode 100644 index 0000000000000000000000000000000000000000..7b4144d70030f0562465a9881ecee825298ddefd --- /dev/null +++ b/path-prefixes.txt @@ -0,0 +1 @@ +205.220.224.0/21 \ No newline at end of file