From 7d0cb2a2adec493b8ad9d79ef35354c8e20f0213 Mon Sep 17 00:00:00 2001 From: Keith Randall <khr@golang.org> Date: Mon, 21 Apr 2025 12:44:24 -0700 Subject: [PATCH] cmd/compile: constant fold 128-bit multiplies The full 64x64->128 multiply comes up when using bits.Mul64. The 64x64->64+overflow multiply comes up in unsafe.Slice when using a constant length. Change-Id: I298515162ca07d804b2d699d03bc957ca30a4ebc Reviewed-on: https://go-review.googlesource.com/c/go/+/667175 Reviewed-by: Junyang Shao <shaojunyang@google.com> Reviewed-by: Keith Randall <khr@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> --- .../compile/internal/ssa/_gen/generic.rules | 4 + src/cmd/compile/internal/ssa/rewrite.go | 11 ++ .../compile/internal/ssa/rewritegeneric.go | 108 ++++++++++++++++++ test/codegen/mathbits.go | 16 ++- 4 files changed, 138 insertions(+), 1 deletion(-) diff --git a/src/cmd/compile/internal/ssa/_gen/generic.rules b/src/cmd/compile/internal/ssa/_gen/generic.rules index eb04d03e496..baa26133fe0 100644 --- a/src/cmd/compile/internal/ssa/_gen/generic.rules +++ b/src/cmd/compile/internal/ssa/_gen/generic.rules @@ -139,6 +139,10 @@ (Mul64 (Const64 [c]) (Const64 [d])) => (Const64 [c*d]) (Mul32F (Const32F [c]) (Const32F [d])) && c*d == c*d => (Const32F [c*d]) (Mul64F (Const64F [c]) (Const64F [d])) && c*d == c*d => (Const64F [c*d]) +(Mul32uhilo (Const32 [c]) (Const32 [d])) => (MakeTuple (Const32 <typ.UInt32> [bitsMulU32(c, d).hi]) (Const32 <typ.UInt32> [bitsMulU32(c,d).lo])) +(Mul64uhilo (Const64 [c]) (Const64 [d])) => (MakeTuple (Const64 <typ.UInt64> [bitsMulU64(c, d).hi]) (Const64 <typ.UInt64> [bitsMulU64(c,d).lo])) +(Mul32uover (Const32 [c]) (Const32 [d])) => (MakeTuple (Const32 <typ.UInt32> [bitsMulU32(c, d).lo]) (ConstBool <typ.Bool> [bitsMulU32(c,d).hi != 0])) +(Mul64uover (Const64 [c]) (Const64 [d])) => (MakeTuple (Const64 <typ.UInt64> [bitsMulU64(c, d).lo]) (ConstBool <typ.Bool> [bitsMulU64(c,d).hi != 0])) (And8 (Const8 [c]) (Const8 [d])) => (Const8 [c&d]) (And16 (Const16 [c]) (Const16 [d])) => (Const16 [c&d]) diff --git a/src/cmd/compile/internal/ssa/rewrite.go b/src/cmd/compile/internal/ssa/rewrite.go index ed79d515463..859814a2d76 100644 --- a/src/cmd/compile/internal/ssa/rewrite.go +++ b/src/cmd/compile/internal/ssa/rewrite.go @@ -2584,3 +2584,14 @@ func bitsAdd64(x, y, carry int64) (r struct{ sum, carry int64 }) { r.sum, r.carry = int64(s), int64(c) return } + +func bitsMulU64(x, y int64) (r struct{ hi, lo int64 }) { + hi, lo := bits.Mul64(uint64(x), uint64(y)) + r.hi, r.lo = int64(hi), int64(lo) + return +} +func bitsMulU32(x, y int32) (r struct{ hi, lo int32 }) { + hi, lo := bits.Mul32(uint32(x), uint32(y)) + r.hi, r.lo = int32(hi), int32(lo) + return +} diff --git a/src/cmd/compile/internal/ssa/rewritegeneric.go b/src/cmd/compile/internal/ssa/rewritegeneric.go index 4fdb22b868e..b8866cc5628 100644 --- a/src/cmd/compile/internal/ssa/rewritegeneric.go +++ b/src/cmd/compile/internal/ssa/rewritegeneric.go @@ -246,12 +246,16 @@ func rewriteValuegeneric(v *Value) bool { return rewriteValuegeneric_OpMul32(v) case OpMul32F: return rewriteValuegeneric_OpMul32F(v) + case OpMul32uhilo: + return rewriteValuegeneric_OpMul32uhilo(v) case OpMul32uover: return rewriteValuegeneric_OpMul32uover(v) case OpMul64: return rewriteValuegeneric_OpMul64(v) case OpMul64F: return rewriteValuegeneric_OpMul64F(v) + case OpMul64uhilo: + return rewriteValuegeneric_OpMul64uhilo(v) case OpMul64uover: return rewriteValuegeneric_OpMul64uover(v) case OpMul8: @@ -18766,10 +18770,62 @@ func rewriteValuegeneric_OpMul32F(v *Value) bool { } return false } +func rewriteValuegeneric_OpMul32uhilo(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + b := v.Block + typ := &b.Func.Config.Types + // match: (Mul32uhilo (Const32 [c]) (Const32 [d])) + // result: (MakeTuple (Const32 <typ.UInt32> [bitsMulU32(c, d).hi]) (Const32 <typ.UInt32> [bitsMulU32(c,d).lo])) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpConst32 { + continue + } + c := auxIntToInt32(v_0.AuxInt) + if v_1.Op != OpConst32 { + continue + } + d := auxIntToInt32(v_1.AuxInt) + v.reset(OpMakeTuple) + v0 := b.NewValue0(v.Pos, OpConst32, typ.UInt32) + v0.AuxInt = int32ToAuxInt(bitsMulU32(c, d).hi) + v1 := b.NewValue0(v.Pos, OpConst32, typ.UInt32) + v1.AuxInt = int32ToAuxInt(bitsMulU32(c, d).lo) + v.AddArg2(v0, v1) + return true + } + break + } + return false +} func rewriteValuegeneric_OpMul32uover(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] b := v.Block + typ := &b.Func.Config.Types + // match: (Mul32uover (Const32 [c]) (Const32 [d])) + // result: (MakeTuple (Const32 <typ.UInt32> [bitsMulU32(c, d).lo]) (ConstBool <typ.Bool> [bitsMulU32(c,d).hi != 0])) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpConst32 { + continue + } + c := auxIntToInt32(v_0.AuxInt) + if v_1.Op != OpConst32 { + continue + } + d := auxIntToInt32(v_1.AuxInt) + v.reset(OpMakeTuple) + v0 := b.NewValue0(v.Pos, OpConst32, typ.UInt32) + v0.AuxInt = int32ToAuxInt(bitsMulU32(c, d).lo) + v1 := b.NewValue0(v.Pos, OpConstBool, typ.Bool) + v1.AuxInt = boolToAuxInt(bitsMulU32(c, d).hi != 0) + v.AddArg2(v0, v1) + return true + } + break + } // match: (Mul32uover <t> (Const32 [1]) x) // result: (MakeTuple x (ConstBool <t.FieldType(1)> [false])) for { @@ -19082,10 +19138,62 @@ func rewriteValuegeneric_OpMul64F(v *Value) bool { } return false } +func rewriteValuegeneric_OpMul64uhilo(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + b := v.Block + typ := &b.Func.Config.Types + // match: (Mul64uhilo (Const64 [c]) (Const64 [d])) + // result: (MakeTuple (Const64 <typ.UInt64> [bitsMulU64(c, d).hi]) (Const64 <typ.UInt64> [bitsMulU64(c,d).lo])) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpConst64 { + continue + } + c := auxIntToInt64(v_0.AuxInt) + if v_1.Op != OpConst64 { + continue + } + d := auxIntToInt64(v_1.AuxInt) + v.reset(OpMakeTuple) + v0 := b.NewValue0(v.Pos, OpConst64, typ.UInt64) + v0.AuxInt = int64ToAuxInt(bitsMulU64(c, d).hi) + v1 := b.NewValue0(v.Pos, OpConst64, typ.UInt64) + v1.AuxInt = int64ToAuxInt(bitsMulU64(c, d).lo) + v.AddArg2(v0, v1) + return true + } + break + } + return false +} func rewriteValuegeneric_OpMul64uover(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] b := v.Block + typ := &b.Func.Config.Types + // match: (Mul64uover (Const64 [c]) (Const64 [d])) + // result: (MakeTuple (Const64 <typ.UInt64> [bitsMulU64(c, d).lo]) (ConstBool <typ.Bool> [bitsMulU64(c,d).hi != 0])) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpConst64 { + continue + } + c := auxIntToInt64(v_0.AuxInt) + if v_1.Op != OpConst64 { + continue + } + d := auxIntToInt64(v_1.AuxInt) + v.reset(OpMakeTuple) + v0 := b.NewValue0(v.Pos, OpConst64, typ.UInt64) + v0.AuxInt = int64ToAuxInt(bitsMulU64(c, d).lo) + v1 := b.NewValue0(v.Pos, OpConstBool, typ.Bool) + v1.AuxInt = boolToAuxInt(bitsMulU64(c, d).hi != 0) + v.AddArg2(v0, v1) + return true + } + break + } // match: (Mul64uover <t> (Const64 [1]) x) // result: (MakeTuple x (ConstBool <t.FieldType(1)> [false])) for { diff --git a/test/codegen/mathbits.go b/test/codegen/mathbits.go index a9cf4667807..873354b838e 100644 --- a/test/codegen/mathbits.go +++ b/test/codegen/mathbits.go @@ -6,7 +6,10 @@ package codegen -import "math/bits" +import ( + "math/bits" + "unsafe" +) // ----------------------- // // bits.LeadingZeros // @@ -957,6 +960,17 @@ func Mul64LoOnly(x, y uint64) uint64 { return lo } +func Mul64Const() (uint64, uint64) { + // 7133701809754865664 == 99<<56 + // arm64:"MOVD\t[$]7133701809754865664, R1", "MOVD\t[$]88, R0" + return bits.Mul64(99+88<<8, 1<<56) +} + +func MulUintOverflow(p *uint64) []uint64 { + // arm64:"CMP\t[$]72" + return unsafe.Slice(p, 9) +} + // --------------- // // bits.Div* // // --------------- // -- GitLab