From 7d0cb2a2adec493b8ad9d79ef35354c8e20f0213 Mon Sep 17 00:00:00 2001
From: Keith Randall <khr@golang.org>
Date: Mon, 21 Apr 2025 12:44:24 -0700
Subject: [PATCH] cmd/compile: constant fold 128-bit multiplies

The full 64x64->128 multiply comes up when using bits.Mul64.
The 64x64->64+overflow multiply comes up in unsafe.Slice when using
a constant length.

Change-Id: I298515162ca07d804b2d699d03bc957ca30a4ebc
Reviewed-on: https://go-review.googlesource.com/c/go/+/667175
Reviewed-by: Junyang Shao <shaojunyang@google.com>
Reviewed-by: Keith Randall <khr@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
---
 .../compile/internal/ssa/_gen/generic.rules   |   4 +
 src/cmd/compile/internal/ssa/rewrite.go       |  11 ++
 .../compile/internal/ssa/rewritegeneric.go    | 108 ++++++++++++++++++
 test/codegen/mathbits.go                      |  16 ++-
 4 files changed, 138 insertions(+), 1 deletion(-)

diff --git a/src/cmd/compile/internal/ssa/_gen/generic.rules b/src/cmd/compile/internal/ssa/_gen/generic.rules
index eb04d03e496..baa26133fe0 100644
--- a/src/cmd/compile/internal/ssa/_gen/generic.rules
+++ b/src/cmd/compile/internal/ssa/_gen/generic.rules
@@ -139,6 +139,10 @@
 (Mul64  (Const64 [c])  (Const64 [d]))  => (Const64 [c*d])
 (Mul32F (Const32F [c]) (Const32F [d])) && c*d == c*d => (Const32F [c*d])
 (Mul64F (Const64F [c]) (Const64F [d])) && c*d == c*d => (Const64F [c*d])
+(Mul32uhilo (Const32 [c]) (Const32 [d])) => (MakeTuple (Const32 <typ.UInt32> [bitsMulU32(c, d).hi]) (Const32 <typ.UInt32> [bitsMulU32(c,d).lo]))
+(Mul64uhilo (Const64 [c]) (Const64 [d])) => (MakeTuple (Const64 <typ.UInt64> [bitsMulU64(c, d).hi]) (Const64 <typ.UInt64> [bitsMulU64(c,d).lo]))
+(Mul32uover (Const32 [c]) (Const32 [d])) => (MakeTuple (Const32 <typ.UInt32> [bitsMulU32(c, d).lo]) (ConstBool <typ.Bool> [bitsMulU32(c,d).hi != 0]))
+(Mul64uover (Const64 [c]) (Const64 [d])) => (MakeTuple (Const64 <typ.UInt64> [bitsMulU64(c, d).lo]) (ConstBool <typ.Bool> [bitsMulU64(c,d).hi != 0]))
 
 (And8   (Const8 [c])   (Const8 [d]))   => (Const8  [c&d])
 (And16  (Const16 [c])  (Const16 [d]))  => (Const16 [c&d])
diff --git a/src/cmd/compile/internal/ssa/rewrite.go b/src/cmd/compile/internal/ssa/rewrite.go
index ed79d515463..859814a2d76 100644
--- a/src/cmd/compile/internal/ssa/rewrite.go
+++ b/src/cmd/compile/internal/ssa/rewrite.go
@@ -2584,3 +2584,14 @@ func bitsAdd64(x, y, carry int64) (r struct{ sum, carry int64 }) {
 	r.sum, r.carry = int64(s), int64(c)
 	return
 }
+
+func bitsMulU64(x, y int64) (r struct{ hi, lo int64 }) {
+	hi, lo := bits.Mul64(uint64(x), uint64(y))
+	r.hi, r.lo = int64(hi), int64(lo)
+	return
+}
+func bitsMulU32(x, y int32) (r struct{ hi, lo int32 }) {
+	hi, lo := bits.Mul32(uint32(x), uint32(y))
+	r.hi, r.lo = int32(hi), int32(lo)
+	return
+}
diff --git a/src/cmd/compile/internal/ssa/rewritegeneric.go b/src/cmd/compile/internal/ssa/rewritegeneric.go
index 4fdb22b868e..b8866cc5628 100644
--- a/src/cmd/compile/internal/ssa/rewritegeneric.go
+++ b/src/cmd/compile/internal/ssa/rewritegeneric.go
@@ -246,12 +246,16 @@ func rewriteValuegeneric(v *Value) bool {
 		return rewriteValuegeneric_OpMul32(v)
 	case OpMul32F:
 		return rewriteValuegeneric_OpMul32F(v)
+	case OpMul32uhilo:
+		return rewriteValuegeneric_OpMul32uhilo(v)
 	case OpMul32uover:
 		return rewriteValuegeneric_OpMul32uover(v)
 	case OpMul64:
 		return rewriteValuegeneric_OpMul64(v)
 	case OpMul64F:
 		return rewriteValuegeneric_OpMul64F(v)
+	case OpMul64uhilo:
+		return rewriteValuegeneric_OpMul64uhilo(v)
 	case OpMul64uover:
 		return rewriteValuegeneric_OpMul64uover(v)
 	case OpMul8:
@@ -18766,10 +18770,62 @@ func rewriteValuegeneric_OpMul32F(v *Value) bool {
 	}
 	return false
 }
+func rewriteValuegeneric_OpMul32uhilo(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	b := v.Block
+	typ := &b.Func.Config.Types
+	// match: (Mul32uhilo (Const32 [c]) (Const32 [d]))
+	// result: (MakeTuple (Const32 <typ.UInt32> [bitsMulU32(c, d).hi]) (Const32 <typ.UInt32> [bitsMulU32(c,d).lo]))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			if v_0.Op != OpConst32 {
+				continue
+			}
+			c := auxIntToInt32(v_0.AuxInt)
+			if v_1.Op != OpConst32 {
+				continue
+			}
+			d := auxIntToInt32(v_1.AuxInt)
+			v.reset(OpMakeTuple)
+			v0 := b.NewValue0(v.Pos, OpConst32, typ.UInt32)
+			v0.AuxInt = int32ToAuxInt(bitsMulU32(c, d).hi)
+			v1 := b.NewValue0(v.Pos, OpConst32, typ.UInt32)
+			v1.AuxInt = int32ToAuxInt(bitsMulU32(c, d).lo)
+			v.AddArg2(v0, v1)
+			return true
+		}
+		break
+	}
+	return false
+}
 func rewriteValuegeneric_OpMul32uover(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
 	b := v.Block
+	typ := &b.Func.Config.Types
+	// match: (Mul32uover (Const32 [c]) (Const32 [d]))
+	// result: (MakeTuple (Const32 <typ.UInt32> [bitsMulU32(c, d).lo]) (ConstBool <typ.Bool> [bitsMulU32(c,d).hi != 0]))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			if v_0.Op != OpConst32 {
+				continue
+			}
+			c := auxIntToInt32(v_0.AuxInt)
+			if v_1.Op != OpConst32 {
+				continue
+			}
+			d := auxIntToInt32(v_1.AuxInt)
+			v.reset(OpMakeTuple)
+			v0 := b.NewValue0(v.Pos, OpConst32, typ.UInt32)
+			v0.AuxInt = int32ToAuxInt(bitsMulU32(c, d).lo)
+			v1 := b.NewValue0(v.Pos, OpConstBool, typ.Bool)
+			v1.AuxInt = boolToAuxInt(bitsMulU32(c, d).hi != 0)
+			v.AddArg2(v0, v1)
+			return true
+		}
+		break
+	}
 	// match: (Mul32uover <t> (Const32 [1]) x)
 	// result: (MakeTuple x (ConstBool <t.FieldType(1)> [false]))
 	for {
@@ -19082,10 +19138,62 @@ func rewriteValuegeneric_OpMul64F(v *Value) bool {
 	}
 	return false
 }
+func rewriteValuegeneric_OpMul64uhilo(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	b := v.Block
+	typ := &b.Func.Config.Types
+	// match: (Mul64uhilo (Const64 [c]) (Const64 [d]))
+	// result: (MakeTuple (Const64 <typ.UInt64> [bitsMulU64(c, d).hi]) (Const64 <typ.UInt64> [bitsMulU64(c,d).lo]))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			if v_0.Op != OpConst64 {
+				continue
+			}
+			c := auxIntToInt64(v_0.AuxInt)
+			if v_1.Op != OpConst64 {
+				continue
+			}
+			d := auxIntToInt64(v_1.AuxInt)
+			v.reset(OpMakeTuple)
+			v0 := b.NewValue0(v.Pos, OpConst64, typ.UInt64)
+			v0.AuxInt = int64ToAuxInt(bitsMulU64(c, d).hi)
+			v1 := b.NewValue0(v.Pos, OpConst64, typ.UInt64)
+			v1.AuxInt = int64ToAuxInt(bitsMulU64(c, d).lo)
+			v.AddArg2(v0, v1)
+			return true
+		}
+		break
+	}
+	return false
+}
 func rewriteValuegeneric_OpMul64uover(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
 	b := v.Block
+	typ := &b.Func.Config.Types
+	// match: (Mul64uover (Const64 [c]) (Const64 [d]))
+	// result: (MakeTuple (Const64 <typ.UInt64> [bitsMulU64(c, d).lo]) (ConstBool <typ.Bool> [bitsMulU64(c,d).hi != 0]))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			if v_0.Op != OpConst64 {
+				continue
+			}
+			c := auxIntToInt64(v_0.AuxInt)
+			if v_1.Op != OpConst64 {
+				continue
+			}
+			d := auxIntToInt64(v_1.AuxInt)
+			v.reset(OpMakeTuple)
+			v0 := b.NewValue0(v.Pos, OpConst64, typ.UInt64)
+			v0.AuxInt = int64ToAuxInt(bitsMulU64(c, d).lo)
+			v1 := b.NewValue0(v.Pos, OpConstBool, typ.Bool)
+			v1.AuxInt = boolToAuxInt(bitsMulU64(c, d).hi != 0)
+			v.AddArg2(v0, v1)
+			return true
+		}
+		break
+	}
 	// match: (Mul64uover <t> (Const64 [1]) x)
 	// result: (MakeTuple x (ConstBool <t.FieldType(1)> [false]))
 	for {
diff --git a/test/codegen/mathbits.go b/test/codegen/mathbits.go
index a9cf4667807..873354b838e 100644
--- a/test/codegen/mathbits.go
+++ b/test/codegen/mathbits.go
@@ -6,7 +6,10 @@
 
 package codegen
 
-import "math/bits"
+import (
+	"math/bits"
+	"unsafe"
+)
 
 // ----------------------- //
 //    bits.LeadingZeros    //
@@ -957,6 +960,17 @@ func Mul64LoOnly(x, y uint64) uint64 {
 	return lo
 }
 
+func Mul64Const() (uint64, uint64) {
+	// 7133701809754865664 == 99<<56
+	// arm64:"MOVD\t[$]7133701809754865664, R1", "MOVD\t[$]88, R0"
+	return bits.Mul64(99+88<<8, 1<<56)
+}
+
+func MulUintOverflow(p *uint64) []uint64 {
+	// arm64:"CMP\t[$]72"
+	return unsafe.Slice(p, 9)
+}
+
 // --------------- //
 //    bits.Div*    //
 // --------------- //
-- 
GitLab