diff --git a/main.go b/main.go index aeb57a10c646e1c4b0e8eec8ce802dc9afa62baa..aa67bc5731efe275b6b2c32c4085575856d8c061 100644 --- a/main.go +++ b/main.go @@ -17,9 +17,12 @@ import ( type options struct { Protocol string - ListenAddr string - TargetAddr4 string - TargetAddr6 string + ListenAddrStr string + TargetAddr4Str string + TargetAddr6Str string + ListenAddr netip.AddrPort + TargetAddr4 netip.AddrPort + TargetAddr6 netip.AddrPort Mark int Verbose int allowedSubnetsPath string @@ -34,9 +37,9 @@ var Opts options func init() { flag.StringVar(&Opts.Protocol, "p", "tcp", "Protocol that will be proxied: tcp, udp") - flag.StringVar(&Opts.ListenAddr, "l", "0.0.0.0:8443", "Address the proxy listens on") - flag.StringVar(&Opts.TargetAddr4, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to") - flag.StringVar(&Opts.TargetAddr6, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to") + flag.StringVar(&Opts.ListenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on") + flag.StringVar(&Opts.TargetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to") + flag.StringVar(&Opts.TargetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to") flag.IntVar(&Opts.Mark, "mark", 0, "The mark that will be set on outbound packets") flag.IntVar(&Opts.Verbose, "v", 0, `0 - no logging of individual connections 1 - log errors occurring in individual connections @@ -50,7 +53,7 @@ func init() { func listen(listenerNum int, errors chan<- error) { logger := Opts.Logger.With(slog.Int("listenerNum", listenerNum), - slog.String("protocol", Opts.Protocol), slog.String("listenAdr", Opts.ListenAddr)) + slog.String("protocol", Opts.Protocol), slog.String("listenAdr", Opts.ListenAddr.String())) listenConfig := net.ListenConfig{} if Opts.Listeners > 1 { @@ -126,20 +129,29 @@ func main() { os.Exit(1) } - if _, err := netip.ParseAddr(Opts.ListenAddr); err != nil { + var err error + if Opts.ListenAddr, err = netip.ParseAddrPort(Opts.ListenAddrStr); err != nil { Opts.Logger.Error("listen address is malformed", "error", err) os.Exit(1) } - if _, err := netip.ParseAddr(Opts.TargetAddr4); err != nil { + if Opts.TargetAddr4, err = netip.ParseAddrPort(Opts.TargetAddr4Str); err != nil { Opts.Logger.Error("ipv4 target address is malformed", "error", err) os.Exit(1) } + if !Opts.TargetAddr4.Addr().Is4() { + Opts.Logger.Error("ipv4 target address is not IPv4") + os.Exit(1) + } - if _, err := netip.ParseAddr(Opts.TargetAddr6); err != nil { + if Opts.TargetAddr6, err = netip.ParseAddrPort(Opts.TargetAddr6Str); err != nil { Opts.Logger.Error("ipv6 target address is malformed", "error", err) os.Exit(1) } + if !Opts.TargetAddr6.Addr().Is6() { + Opts.Logger.Error("ipv6 target address is not IPv6") + os.Exit(1) + } if Opts.udpCloseAfter < 0 { Opts.Logger.Error("--close-after has to be >= 0", slog.Int("close-after", Opts.udpCloseAfter)) diff --git a/tcp.go b/tcp.go index a0db342a05ce573a5f2060ed99738961625026b2..878cfde8a39c90d5e13505246f811e8aab1c20db 100644 --- a/tcp.go +++ b/tcp.go @@ -63,7 +63,7 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { if saddr != nil { clientAddr = saddr.String() } - logger = logger.With(slog.String("clientAddr", clientAddr), slog.String("targetAddr", targetAddr)) + logger = logger.With(slog.String("clientAddr", clientAddr), slog.String("targetAddr", targetAddr.String())) if Opts.Verbose > 1 { logger.Debug("successfully parsed PROXY header") } @@ -72,7 +72,7 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { if saddr != nil { dialer.Control = DialUpstreamControl(saddr.(*net.TCPAddr).Port) } - upstreamConn, err := dialer.Dial("tcp", targetAddr) + upstreamConn, err := dialer.Dial("tcp", targetAddr.String()) if err != nil { logger.Debug("failed to establish upstream connection", "error", err, slog.Bool("dropConnection", true)) return @@ -122,7 +122,7 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { func TCPListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan<- error) { ctx := context.Background() - ln, err := listenConfig.Listen(ctx, "tcp", Opts.ListenAddr) + ln, err := listenConfig.Listen(ctx, "tcp", Opts.ListenAddr.String()) if err != nil { logger.Error("failed to bind listener", "error", err) errors <- err diff --git a/udp.go b/udp.go index 4373ab1a02d4456586e85e167d3775ae85be12aa..ca1344100f2c54e916d4682f20e122e84dcad033 100644 --- a/udp.go +++ b/udp.go @@ -98,7 +98,7 @@ func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Ad targetAddr = Opts.TargetAddr4 } - logger = logger.With(slog.String("downstreamAddr", downstreamAddr.String()), slog.String("targetAddr", targetAddr)) + logger = logger.With(slog.String("downstreamAddr", downstreamAddr.String()), slog.String("targetAddr", targetAddr.String())) dialer := net.Dialer{LocalAddr: saddr} if saddr != nil { logger = logger.With(slog.String("clientAddr", saddr.String())) @@ -109,7 +109,7 @@ func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Ad logger.Debug("new connection") } - conn, err := dialer.Dial("udp", targetAddr) + conn, err := dialer.Dial("udp", targetAddr.String()) if err != nil { logger.Debug("failed to connect to upstream", "error", err) return nil, err @@ -132,7 +132,7 @@ func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Ad func UDPListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan<- error) { ctx := context.Background() - ln, err := listenConfig.ListenPacket(ctx, "udp", Opts.ListenAddr) + ln, err := listenConfig.ListenPacket(ctx, "udp", Opts.ListenAddr.String()) if err != nil { logger.Error("failed to bind listener", "error", err) errors <- err