diff --git a/README.md b/README.md index 24ee834ad8a6f0194b3e5a3022029cc21605f5a8..c5c2c3274d2972501a6c72b984f3f7da27b35bad 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,8 @@ Usage of ./go-mmproxy: Path to a file that contains allowed subnets of the proxy servers -l string Adress the proxy listens on (default "0.0.0.0:8443") + -listeners int + Number of listener sockets that will be opened for the listen address (default 1) -mark int The mark that will be set on outbound packets -v int @@ -57,6 +59,7 @@ Usage of ./go-mmproxy: 1 - log errors occuring in individual connections 2 - log all state changes of individual connections + ``` Example invocation: diff --git a/main.go b/main.go index 187664b9eddce055e1e471061e297249f4ee86b3..194bc00614b07cb7874fec3230933b15b5330f00 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ package main import ( "bufio" "bytes" + "context" "encoding/binary" "flag" "fmt" @@ -26,6 +27,7 @@ var targetAddr6 string var allowedSubnetsPath string var mark int var verbose int +var listeners int var allowedSubnets []*net.IPNet var logger *zap.Logger @@ -40,6 +42,7 @@ func init() { flag.IntVar(&verbose, "v", 0, `0 - no logging of individual connections 1 - log errors occuring in individual connections 2 - log all state changes of individual connections`) + flag.IntVar(&listeners, "listeners", 1, "Number of listener sockets that will be opened for the listen address") } func readRemoteAddrPROXYv2(conn net.Conn, ctrlBuf []byte) (net.Addr, net.Addr, []byte, error) { @@ -233,9 +236,9 @@ func checkOriginAllowed(conn net.Conn) bool { return false } -func handleConnection(conn net.Conn) { +func handleConnection(conn net.Conn, listenLog *zap.Logger) { defer conn.Close() - connLog := logger.With(zap.String("remoteAddr", conn.RemoteAddr().String()), + connLog := listenLog.With(zap.String("remoteAddr", conn.RemoteAddr().String()), zap.String("localAddr", conn.LocalAddr().String())) if !checkOriginAllowed(conn) { @@ -309,6 +312,43 @@ func handleConnection(conn net.Conn) { } } +func listen(listenerNum int, errors chan<- error) { + listenLog := logger.With(zap.Int("listenerNum", listenerNum)) + + listenConfig := net.ListenConfig{} + if listeners > 1 { + listenConfig.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + soReusePort := 15 + if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, soReusePort, 1); err != nil { + listenLog.Warn("failed to set SO_REUSEPORT - only one listener setup will succeed") + } + }) + } + } + + ctx := context.Background() + ln, err := listenConfig.Listen(ctx, "tcp", listenAddr) + if err != nil { + listenLog.Error("failed to bind listener", zap.String("listenAddr", listenAddr), zap.Error(err)) + errors <- err + return + } + + listenLog.Info("listening", zap.String("listenAddr", listenAddr)) + + for { + conn, err := ln.Accept() + if err != nil { + listenLog.Error("failed to accept new connection", zap.Error(err)) + errors <- err + return + } + + go handleConnection(conn, listenLog) + } +} + func loadAllowedSubnets() error { file, err := os.Open(allowedSubnetsPath) if err != nil { @@ -345,29 +385,26 @@ func initLogger() error { func main() { flag.Parse() - if err := initLogger(); err != nil { log.Fatalf("Failed to initialize logging: %s", err.Error()) } defer logger.Sync() + if listeners <= 0 { + logger.Fatal("--listeners has to be >= 1") + } + if allowedSubnetsPath != "" { if err := loadAllowedSubnets(); err != nil { logger.Fatal("failed to load allowed subnets file", zap.String("path", allowedSubnetsPath), zap.Error(err)) } } - ln, err := net.Listen("tcp", listenAddr) - if err != nil { - logger.Fatal("failed to bind listener", zap.String("listenAddr", listenAddr), zap.Error(err)) + listenErrors := make(chan error, listeners) + for i := 0; i < listeners; i++ { + go listen(i, listenErrors) } - - for { - conn, err := ln.Accept() - if err != nil { - logger.Fatal("failed to accept new connection", zap.Error(err)) - } - - go handleConnection(conn) + for i := 0; i < listeners; i++ { + <-listenErrors } }