Skip to content
Snippets Groups Projects
Unverified Commit 68755bf0 authored by Stefan Majer's avatar Stefan Majer Committed by GitHub
Browse files

Use net/netip to check for ipv4 (#25)

parent 0d53f620
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"flag" "flag"
"log/slog" "log/slog"
"net" "net"
"net/netip"
"os" "os"
"syscall" "syscall"
"time" "time"
...@@ -125,6 +126,21 @@ func main() { ...@@ -125,6 +126,21 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if _, err := netip.ParseAddr(Opts.ListenAddr); err != nil {
Opts.Logger.Error("listen address is malformed", "error", err)
os.Exit(1)
}
if _, err := netip.ParseAddr(Opts.TargetAddr4); err != nil {
Opts.Logger.Error("ipv4 target address is malformed", "error", err)
os.Exit(1)
}
if _, err := netip.ParseAddr(Opts.TargetAddr6); err != nil {
Opts.Logger.Error("ipv6 target address is malformed", "error", err)
os.Exit(1)
}
if Opts.udpCloseAfter < 0 { if Opts.udpCloseAfter < 0 {
Opts.Logger.Error("--close-after has to be >= 0", slog.Int("close-after", Opts.udpCloseAfter)) Opts.Logger.Error("--close-after has to be >= 0", slog.Int("close-after", Opts.udpCloseAfter))
os.Exit(1) os.Exit(1)
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"io" "io"
"log/slog" "log/slog"
"net" "net"
"net/netip"
) )
func tcpCopyData(dst net.Conn, src net.Conn, ch chan<- error) { func tcpCopyData(dst net.Conn, src net.Conn, ch chan<- error) {
...@@ -51,10 +52,10 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { ...@@ -51,10 +52,10 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) {
targetAddr := Opts.TargetAddr6 targetAddr := Opts.TargetAddr6
if saddr == nil { if saddr == nil {
if AddrVersion(conn.RemoteAddr()) == 4 { if netip.MustParseAddr(conn.RemoteAddr().String()).Is4() {
targetAddr = Opts.TargetAddr4 targetAddr = Opts.TargetAddr4
} }
} else if AddrVersion(saddr) == 4 { } else if netip.MustParseAddr(saddr.String()).Is4() {
targetAddr = Opts.TargetAddr4 targetAddr = Opts.TargetAddr4
} }
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"errors" "errors"
"log/slog" "log/slog"
"net" "net"
"net/netip"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"time" "time"
...@@ -93,7 +94,7 @@ func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Ad ...@@ -93,7 +94,7 @@ func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Ad
} }
targetAddr := Opts.TargetAddr6 targetAddr := Opts.TargetAddr6
if AddrVersion(downstreamAddr) == 4 { if netip.MustParseAddr(downstreamAddr.String()).Is4() {
targetAddr = Opts.TargetAddr4 targetAddr = Opts.TargetAddr4
} }
......
...@@ -7,7 +7,6 @@ package main ...@@ -7,7 +7,6 @@ package main
import ( import (
"fmt" "fmt"
"net" "net"
"strings"
"syscall" "syscall"
) )
...@@ -87,11 +86,3 @@ func DialUpstreamControl(sport int) func(string, string, syscall.RawConn) error ...@@ -87,11 +86,3 @@ func DialUpstreamControl(sport int) func(string, string, syscall.RawConn) error
return syscallErr return syscallErr
} }
} }
func AddrVersion(addr net.Addr) int {
// poor man's ipv6 check - golang makes it unnecessarily hard
if strings.ContainsRune(addr.String(), '.') {
return 4
}
return 6
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment