From 68755bf0ee0f81a3c38bcc85ac04997114a9ef72 Mon Sep 17 00:00:00 2001
From: Stefan Majer <stefan.majer@f-i-ts.de>
Date: Thu, 19 Oct 2023 13:54:19 +0200
Subject: [PATCH] Use net/netip to check for ipv4 (#25)

---
 main.go  | 16 ++++++++++++++++
 tcp.go   |  5 +++--
 udp.go   |  3 ++-
 utils.go |  9 ---------
 4 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/main.go b/main.go
index 9062dfa..aeb57a1 100644
--- a/main.go
+++ b/main.go
@@ -9,6 +9,7 @@ import (
 	"flag"
 	"log/slog"
 	"net"
+	"net/netip"
 	"os"
 	"syscall"
 	"time"
@@ -125,6 +126,21 @@ func main() {
 		os.Exit(1)
 	}
 
+	if _, err := netip.ParseAddr(Opts.ListenAddr); err != nil {
+		Opts.Logger.Error("listen address is malformed", "error", err)
+		os.Exit(1)
+	}
+
+	if _, err := netip.ParseAddr(Opts.TargetAddr4); err != nil {
+		Opts.Logger.Error("ipv4 target address is malformed", "error", err)
+		os.Exit(1)
+	}
+
+	if _, err := netip.ParseAddr(Opts.TargetAddr6); err != nil {
+		Opts.Logger.Error("ipv6 target address is malformed", "error", err)
+		os.Exit(1)
+	}
+
 	if Opts.udpCloseAfter < 0 {
 		Opts.Logger.Error("--close-after has to be >= 0", slog.Int("close-after", Opts.udpCloseAfter))
 		os.Exit(1)
diff --git a/tcp.go b/tcp.go
index f8eb1ee..a0db342 100644
--- a/tcp.go
+++ b/tcp.go
@@ -9,6 +9,7 @@ import (
 	"io"
 	"log/slog"
 	"net"
+	"net/netip"
 )
 
 func tcpCopyData(dst net.Conn, src net.Conn, ch chan<- error) {
@@ -51,10 +52,10 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) {
 
 	targetAddr := Opts.TargetAddr6
 	if saddr == nil {
-		if AddrVersion(conn.RemoteAddr()) == 4 {
+		if netip.MustParseAddr(conn.RemoteAddr().String()).Is4() {
 			targetAddr = Opts.TargetAddr4
 		}
-	} else if AddrVersion(saddr) == 4 {
+	} else if netip.MustParseAddr(saddr.String()).Is4() {
 		targetAddr = Opts.TargetAddr4
 	}
 
diff --git a/udp.go b/udp.go
index 5eb2054..4373ab1 100644
--- a/udp.go
+++ b/udp.go
@@ -9,6 +9,7 @@ import (
 	"errors"
 	"log/slog"
 	"net"
+	"net/netip"
 	"sync/atomic"
 	"syscall"
 	"time"
@@ -93,7 +94,7 @@ func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Ad
 	}
 
 	targetAddr := Opts.TargetAddr6
-	if AddrVersion(downstreamAddr) == 4 {
+	if netip.MustParseAddr(downstreamAddr.String()).Is4() {
 		targetAddr = Opts.TargetAddr4
 	}
 
diff --git a/utils.go b/utils.go
index 02d7490..432c152 100644
--- a/utils.go
+++ b/utils.go
@@ -7,7 +7,6 @@ package main
 import (
 	"fmt"
 	"net"
-	"strings"
 	"syscall"
 )
 
@@ -87,11 +86,3 @@ func DialUpstreamControl(sport int) func(string, string, syscall.RawConn) error
 		return syscallErr
 	}
 }
-
-func AddrVersion(addr net.Addr) int {
-	// poor man's ipv6 check - golang makes it unnecessarily hard
-	if strings.ContainsRune(addr.String(), '.') {
-		return 4
-	}
-	return 6
-}
-- 
GitLab