Commit 39070da4 authored by Russ Cox, committed by Gopher Robot

math/big: add shift and mul to mini-compiler

Step 3 of the mini-compiler: add the generators for the shift and mul routines.

Change-Id: I981d5b7086262c740036f5db768d3e63083984e2
Reviewed-on: https://go-review.googlesource.com/c/go/+/664937


Auto-Submit: Russ Cox <rsc@golang.org>
Reviewed-by: Alan Donovan <adonovan@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
parent 2a881066
......@@ -24,10 +24,15 @@ type Arch struct {
// Registers.
regs []string // usable general registers, in allocation order
reg0 string // dedicated zero register
regCarry string // dedicated carry register
regAltCarry string // dedicated secondary carry register
regCarry string // dedicated carry register, for systems with no hardware carry bits
regAltCarry string // dedicated secondary carry register, for systems with no hardware carry bits
regTmp string // dedicated temporary register
// regShift indicates that the architecture supports
// using REG1>>REG2 and REG1<<REG2 as the first source
// operand in an arithmetic instruction. (32-bit ARM does this.)
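// For example (hypothetical registers), ARM can fold the shift into
// the arithmetic instruction itself:
//
//	ADD R1<<R2, R3, R4  (R4 = R3 + R1<<R2)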
regShift bool
// setup is called to emit any per-architecture function prologue,
// immediately after the TEXT line has been emitted.
// If setup is nil, it is taken to be a no-op.
......@@ -86,13 +91,6 @@ type Arch struct {
addF func(a *Asm, src1, src2, dst Reg, carry Carry) bool
subF func(a *Asm, src1, src2, dst Reg, carry Carry) bool
// lshF and rshF implement a.Lsh and a.Rsh
// on systems where the situation is more complicated than
// a simple instruction opcode.
// They must succeed.
lshF func(a *Asm, shift, src, dst Reg)
rshF func(a *Asm, shift, src, dst Reg)
// mulF and mulWideF implement Mul and MulWide.
// They call Fatalf if the operation is unsupported.
// An architecture can set the mul field instead of mulF.
......
......@@ -4,8 +4,6 @@
package asmgen
import "strings"
var ArchARM = &Arch{
Name: "arm",
WordBits: 32,
......@@ -20,6 +18,7 @@ var ArchARM = &Arch{
// R15 is PC.
"R0", "R1", "R2", "R3", "R4", "R5", "R6", "R7", "R8", "R9", "R11", "R12",
},
regShift: true,
mov: "MOVW",
add: "ADD",
......@@ -34,8 +33,6 @@ var ArchARM = &Arch{
and: "AND",
or: "ORR",
xor: "EOR",
lshF: armLsh,
rshF: armRsh,
mulWideF: armMulWide,
......@@ -50,14 +47,6 @@ var ArchARM = &Arch{
storeDecN: armStoreDecN,
}
func armLsh(a *Asm, shift, src, dst Reg) {
a.Printf("\tMOVW %s<<%s, %s\n", src, strings.TrimPrefix(shift.String(), "$"), dst)
}
func armRsh(a *Asm, shift, src, dst Reg) {
a.Printf("\tMOVW %s>>%s, %s\n", src, strings.TrimPrefix(shift.String(), "$"), dst)
}
func armMulWide(a *Asm, src1, src2, dstlo, dsthi Reg) {
a.Printf("\tMULLU %s, %s, (%s, %s)\n", src1, src2, dsthi, dstlo)
}
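// For example (hypothetical registers), armMulWide(a, R1, R2, R3, R4) emits
//
//	MULLU R1, R2, (R4, R3)
//
// placing the high word of the 64-bit product in R4 and the low word in R3.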
......
......@@ -328,14 +328,28 @@ func (a *Asm) Neg(src, dst Reg) {
}
}
// HasRegShift reports whether the architecture can use shift expressions as operands.
func (a *Asm) HasRegShift() bool {
return a.Arch.regShift
}
// LshReg returns a shift-expression operand src<<shift.
// If a.HasRegShift() == false, LshReg panics.
func (a *Asm) LshReg(shift, src Reg) Reg {
if !a.HasRegShift() {
a.Fatalf("no reg shift")
}
return Reg{fmt.Sprintf("%s<<%s", src, strings.TrimPrefix(shift.name, "$"))}
}
// Lsh emits dst = src << shift.
// It may modify the carry flag.
func (a *Asm) Lsh(shift, src, dst Reg) {
if need := a.hint(HintShiftCount); need != "" && shift.name != need && !shift.IsImm() {
a.Fatalf("shift count not in %s", need)
}
if a.Arch.lshF != nil {
a.Arch.lshF(a, shift, src, dst)
if a.HasRegShift() {
a.Mov(a.LshReg(shift, src), dst)
return
}
a.op3(a.Arch.lsh, shift, src, dst)
......@@ -353,14 +367,23 @@ func (a *Asm) LshWide(shift, adj, src, dst Reg) {
a.op3(fmt.Sprintf("%s %s,", a.Arch.lshd, shift), adj, src, dst)
}
// RshReg returns a shift-expression operand src>>shift.
// If a.HasRegShift() == false, RshReg panics.
func (a *Asm) RshReg(shift, src Reg) Reg {
if !a.HasRegShift() {
a.Fatalf("no reg shift")
}
return Reg{fmt.Sprintf("%s>>%s", src, strings.TrimPrefix(shift.name, "$"))}
}
// Rsh emits dst = src >> shift.
// It may modify the carry flag.
func (a *Asm) Rsh(shift, src, dst Reg) {
if need := a.hint(HintShiftCount); need != "" && shift.name != need && !shift.IsImm() {
a.Fatalf("shift count not in %s", need)
}
if a.Arch.rshF != nil {
a.Arch.rshF(a, shift, src, dst)
if a.HasRegShift() {
a.Mov(a.RshReg(shift, src), dst)
return
}
a.op3(a.Arch.rsh, shift, src, dst)
......
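For example, on an architecture with regShift set (32-bit ARM), Lsh and Rsh no longer need per-architecture hooks: they reduce to a Mov of a shift-expression operand. A minimal usage sketch, assuming an *Asm value a and Reg values shift, src, dst:

	a.Lsh(shift, src, dst)           // emits: MOVW src<<shift, dst
	a.Mov(a.LshReg(shift, src), dst) // equivalent, spelled explicitly

Arithmetic instructions can consume the same operand directly; shiftVU below does exactly that with a.Or(negShiftReg(sNeg, x), prev, prev).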
......@@ -36,16 +36,19 @@ func loop(x int) int {
s := 0
for i := 1; i < x; i++ {
s += i
if s == 98 {
if s == 98 { // useful for jmpEqual
return 99
}
if s == 99 {
return 100
}
if s == 0 {
if s == 0 { // useful for jmpZero
return 101
}
s += 2
if s != 0 { // useful for jmpNonZero
s *= 3
}
s += 2 // keep last condition from being inverted
}
return s
}
......
......@@ -33,5 +33,9 @@ func generate(arch *Arch) (file string, data []byte) {
a := NewAsm(arch)
addOrSubVV(a, "addVV")
addOrSubVV(a, "subVV")
shiftVU(a, "lshVU")
shiftVU(a, "rshVU")
mulAddVWW(a)
addMulVVWW(a)
return file, a.out.Bytes()
}
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package asmgen
// mulAddVWW generates mulAddVWW, which does z, c = x*m + a.
func mulAddVWW(a *Asm) {
f := a.Func("func mulAddVWW(z, x []Word, m, a Word) (c Word)")
if a.AltCarry().Valid() {
addMulVirtualCarry(f, 0)
return
}
addMul(f, "", "x", 0)
}
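// For reference, a minimal pure-Go sketch of the semantics the generated
// assembly implements, using math/bits (mulAddVWWRef is a hypothetical name,
// not part of the generator; Word is assumed to be an unsigned word-sized
// integer):
//
//	func mulAddVWWRef(z, x []Word, m, a Word) (c Word) {
//		c = a
//		for i := range z {
//			hi, lo := bits.Mul(uint(m), uint(x[i]))
//			lo, cc := bits.Add(lo, uint(c), 0)
//			z[i] = Word(lo)
//			c = Word(hi + cc) // cannot overflow
//		}
//		return
//	}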
// addMulVVWW generates addMulVVWW, which does z, c = x + y*m + a.
// (A more pedantic name would be addMulAddVVWW.)
func addMulVVWW(a *Asm) {
f := a.Func("func addMulVVWW(z, x, y []Word, m, a Word) (c Word)")
// If the architecture has virtual carries, emit that version unconditionally.
if a.AltCarry().Valid() {
addMulVirtualCarry(f, 1)
return
}
// If the architecture optionally has two carries, test and emit both versions.
if a.JmpEnable(OptionAltCarry, "altcarry") {
regs := a.RegsUsed()
addMul(f, "x", "y", 1)
a.Label("altcarry")
a.SetOption(OptionAltCarry, true)
a.SetRegsUsed(regs)
addMulAlt(f)
a.SetOption(OptionAltCarry, false)
return
}
// Otherwise emit the one-carry form.
addMul(f, "x", "y", 1)
}
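// For reference, a pure-Go sketch of the addMulVVWW semantics, following the
// word-at-a-time loop in the large comment below (addMulVVWWRef is a
// hypothetical name; Word is assumed to be an unsigned word-sized integer):
//
//	func addMulVVWWRef(z, x, y []Word, m, a Word) (c Word) {
//		c = a
//		for i := range z {
//			hi, lo := bits.Mul(uint(m), uint(y[i]))
//			lo, carry := bits.Add(lo, uint(c), 0)
//			lo, carryAlt := bits.Add(lo, uint(x[i]), 0)
//			z[i] = Word(lo)
//			c = Word(hi + carry + carryAlt) // cannot overflow
//		}
//		return
//	}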
// Computing z = addsrc + m*mulsrc + a, we need:
//
// for i := range z {
// lo, hi := m * mulsrc[i]
// lo, carry = bits.Add(lo, a, 0)
// lo, carryAlt = bits.Add(lo, addsrc[i], 0)
// z[i] = lo
// a = hi + carry + carryAlt // cannot overflow
// }
//
// The final addition cannot overflow because after processing N words,
// the maximum possible value is (for a 64-bit system):
//
// (2**(64N) - 1) + (2**64 - 1)*(2**(64N) - 1) + (2**64 - 1)
//   = (2**64)*(2**(64N) - 1) + (2**64 - 1)
//   = 2**(64(N+1)) - 1,
//
// which fits in N+1 words (the high-order one being the new value of a).
//
// (For example, with 3 decimal words, 999 + 9*999 + 9 = 999*10 + 9 = 9999.)
//
// If we unroll the loop a bit, then we can chain the carries in two passes.
// Consider:
//
// lo0, hi0 := m * mulsrc[i]
// lo0, carry = bits.Add(lo0, a, 0)
// lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
// z[i] = lo0
// a = hi0 + carry + carryAlt // cannot overflow
//
// lo1, hi1 := m * mulsrc[i+1]
// lo1, carry = bits.Add(lo1, a, 0)
// lo1, carryAlt = bits.Add(lo1, addsrc[i+1], 0)
// z[i+1] = lo1
// a = hi1 + carry + carryAlt // cannot overflow
//
// lo2, hi2 := m * mulsrc[i+2]
// lo2, carry = bits.Add(lo2, a, 0)
// lo2, carryAlt = bits.Add(lo2, addsrc[i+2], 0)
// z[i+2] = lo2
// a = hi2 + carry + carryAlt // cannot overflow
//
// lo3, hi3 := m * mulsrc[i+3]
// lo3, carry = bits.Add(lo3, a, 0)
// lo3, carryAlt = bits.Add(lo3, addsrc[i+3], 0)
// z[i+3] = lo3
// a = hi3 + carry + carryAlt // cannot overflow
//
// There are three ways we can optimize this sequence.
//
// (1) Reordering, we can chain carries so that we can use one hardware carry flag
// but amortize the cost of saving and restoring it across multiple instructions:
//
// // multiply
// lo0, hi0 := m * mulsrc[i]
// lo1, hi1 := m * mulsrc[i+1]
// lo2, hi2 := m * mulsrc[i+2]
// lo3, hi3 := m * mulsrc[i+3]
//
// lo0, carry = bits.Add(lo0, a, 0)
// lo1, carry = bits.Add(lo1, hi0, carry)
// lo2, carry = bits.Add(lo2, hi1, carry)
// lo3, carry = bits.Add(lo3, hi2, carry)
// a = hi3 + carry // cannot overflow
//
// // add
// lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
// lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
// lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
// lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
// a = a + carryAlt // cannot overflow
//
// z[i] = lo0
// z[i+1] = lo1
// z[i+2] = lo2
// z[i+3] = lo3
//
// addMul takes this approach, using the hardware carry flag
// first for carry and then for carryAlt.
//
// (2) addMulAlt assumes there are two hardware carry flags available.
// It dedicates one each to carry and carryAlt, so that a multi-block
// unrolling can keep the flags in hardware across all the blocks.
// So even if the block size is 1, the code can do:
//
// // multiply and add
// lo0, hi0 := m * mulsrc[i]
// lo0, carry = bits.Add(lo0, a, 0)
// lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
// z[i] = lo0
//
// lo1, hi1 := m * mulsrc[i+1]
// lo1, carry = bits.Add(lo1, hi0, carry)
// lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
// z[i+1] = lo1
//
// lo2, hi2 := m * mulsrc[i+2]
// lo2, carry = bits.Add(lo2, hi1, carry)
// lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
// z[i+2] = lo2
//
// lo3, hi3 := m * mulsrc[i+3]
// lo3, carry = bits.Add(lo3, hi2, carry)
// lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
// z[i+3] = lo3
//
// a = hi3 + carry + carryAlt // cannot overflow
//
// (3) addMulVirtualCarry optimizes for systems with explicitly computed carry bits
// (loong64, mips, riscv64), cutting the number of actual instructions almost by half.
// Look again at the original word-at-a-time version:
//
// lo1, hi1 := m * mulsrc[i]
// lo1, carry = bits.Add(lo1, a, 0)
// lo1, carryAlt = bits.Add(lo1, addsrc[i], 0)
// z[i] = lo1
// a = hi1 + carry + carryAlt // cannot overflow
//
// Although it uses four adds per word, those are cheap adds: the two bits.Add adds
// use two instructions each (ADD+SLTU), and the two final adds use only one ADD each,
// for a total of 6 instructions per word. In contrast, the middle stanzas in (2) use
// only two “adds” per word, but those are SetCarry|UseCarry adds, which compile to
// five instructions each, for a total of 10 instructions per word. So the word-at-a-time
// loop is actually better. And we can reorder things slightly to use only a single carry bit:
//
// lo1, hi1 := m * mulsrc[i]
// lo1, carry = bits.Add(lo1, a, 0)
// a = hi1 + carry
// lo1, carry = bits.Add(lo1, addsrc[i], 0)
// a = a + carry
// z[i] = lo1
func addMul(f *Func, addsrc, mulsrc string, mulIndex int) {
a := f.Asm
mh := HintNone
if a.Arch == Arch386 && addsrc != "" {
mh = HintMemOK // too few registers otherwise
}
m := f.ArgHint("m", mh)
c := f.Arg("a")
n := f.Arg("z_len")
p := f.Pipe()
if addsrc != "" {
p.SetHint(addsrc, HintMemOK)
}
p.SetHint(mulsrc, HintMulSrc)
unroll := []int{1, 4}
switch a.Arch {
case Arch386:
unroll = []int{1} // too few registers
case ArchARM:
p.SetMaxColumns(2) // too few registers (but more than 386)
case ArchARM64:
unroll = []int{1, 8} // 5% speedup on c4as16
}
// See the large comment above for an explanation of the code being generated.
// This is optimization strategy (1).
p.Start(n, unroll...)
p.Loop(func(in, out [][]Reg) {
a.Comment("multiply")
prev := c
flag := SetCarry
for i, x := range in[mulIndex] {
hi := a.RegHint(HintMulHi)
a.MulWide(m, x, x, hi)
a.Add(prev, x, x, flag)
flag = UseCarry | SetCarry
if prev != c {
a.Free(prev)
}
out[0][i] = x
prev = hi
}
a.Add(a.Imm(0), prev, c, UseCarry|SmashCarry)
if addsrc != "" {
a.Comment("add")
flag := SetCarry
for i, x := range in[0] {
a.Add(x, out[0][i], out[0][i], flag)
flag = UseCarry | SetCarry
}
a.Add(a.Imm(0), c, c, UseCarry|SmashCarry)
}
p.StoreN(out)
})
f.StoreArg(c, "c")
a.Ret()
}
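// For illustration, a pure-Go rendering of strategy (1)'s two-pass carry
// chaining for a block of two words (addMulBlock2 is a hypothetical helper,
// not part of the generator; Word is assumed unsigned and word-sized):
//
//	func addMulBlock2(m, a, mul0, mul1, add0, add1 Word) (z0, z1, aOut Word) {
//		// multiply: chain the multiply carries through the carry flag
//		hi0, lo0 := bits.Mul(uint(m), uint(mul0))
//		hi1, lo1 := bits.Mul(uint(m), uint(mul1))
//		lo0, carry := bits.Add(lo0, uint(a), 0)
//		lo1, carry = bits.Add(lo1, hi0, carry)
//		aw := hi1 + carry // cannot overflow
//		// add: chain the addition carries in a second pass
//		lo0, carryAlt := bits.Add(lo0, uint(add0), 0)
//		lo1, carryAlt = bits.Add(lo1, uint(add1), carryAlt)
//		aw += carryAlt // cannot overflow
//		return Word(lo0), Word(lo1), Word(aw)
//	}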
func addMulAlt(f *Func) {
a := f.Asm
m := f.ArgHint("m", HintMulSrc)
c := f.Arg("a")
n := f.Arg("z_len")
// On amd64, we need a non-immediate for the AtUnrollEnd adds.
r0 := a.ZR()
if !r0.Valid() {
r0 = a.Reg()
a.Mov(a.Imm(0), r0)
}
p := f.Pipe()
p.SetLabel("alt")
p.SetHint("x", HintMemOK)
p.SetHint("y", HintMemOK)
if a.Arch == ArchAMD64 {
p.SetMaxColumns(2)
}
// See the large comment above for an explanation of the code being generated.
// This is optimization strategy (2).
var hi Reg
prev := c
p.Start(n, 1, 8)
p.AtUnrollStart(func() {
a.Comment("multiply and add")
a.ClearCarry(AddCarry | AltCarry)
a.ClearCarry(AddCarry)
hi = a.Reg()
})
p.AtUnrollEnd(func() {
a.Add(r0, prev, c, UseCarry|SmashCarry)
a.Add(r0, c, c, UseCarry|SmashCarry|AltCarry)
prev = c
})
p.Loop(func(in, out [][]Reg) {
for i, y := range in[1] {
x := in[0][i]
lo := y
if lo.IsMem() {
lo = a.Reg()
}
a.MulWide(m, y, lo, hi)
a.Add(prev, lo, lo, UseCarry|SetCarry)
a.Add(x, lo, lo, UseCarry|SetCarry|AltCarry)
out[0][i] = lo
prev, hi = hi, prev
}
p.StoreN(out)
})
f.StoreArg(c, "c")
a.Ret()
}
func addMulVirtualCarry(f *Func, mulIndex int) {
a := f.Asm
m := f.Arg("m")
c := f.Arg("a")
n := f.Arg("z_len")
// See the large comment above for an explanation of the code being generated.
// This is optimization strategy (3).
p := f.Pipe()
p.Start(n, 1, 4)
p.Loop(func(in, out [][]Reg) {
a.Comment("synthetic carry, one column at a time")
lo, hi := a.Reg(), a.Reg()
for i, x := range in[mulIndex] {
a.MulWide(m, x, lo, hi)
if mulIndex == 1 {
a.Add(in[0][i], lo, lo, SetCarry)
a.Add(a.Imm(0), hi, hi, UseCarry|SmashCarry)
}
a.Add(c, lo, x, SetCarry)
a.Add(a.Imm(0), hi, c, UseCarry|SmashCarry)
out[0][i] = x
}
p.StoreN(out)
})
f.StoreArg(c, "c")
a.Ret()
}
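// For illustration, a pure-Go rendering of one step of strategy (3)'s
// reordered single-carry loop (stepVirtualCarry is a hypothetical helper,
// not part of the generator):
//
//	func stepVirtualCarry(m, x, add, a Word) (zi, aOut Word) {
//		hi, lo := bits.Mul(uint(m), uint(x))
//		lo, carry := bits.Add(lo, uint(a), 0)
//		hi += carry // a = hi + carry
//		lo, carry = bits.Add(lo, uint(add), 0)
//		hi += carry // a = a + carry
//		return Word(lo), Word(hi)
//	}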
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package asmgen
// shiftVU generates lshVU and rshVU, which do
// z, c = x << s and z, c = x >> s, for 0 < s < _W.
func shiftVU(a *Asm, name string) {
// Because these routines can be called for z.Lsh(z, N) and z.Rsh(z, N),
// the input and output slices may be aliased at different offsets.
// For example (on 64-bit systems), during z.Lsh(z, 65), &z[0] == &x[1],
// and during z.Rsh(z, 65), &z[1] == &x[0].
// For left shift, we must process the slices from len(z)-1 down to 0,
// so that we don't overwrite a word before we need to read it.
// For right shift, we must process the slices from 0 up to len(z)-1.
// The different traversals at least make the two cases more consistent,
// since we're always delaying the output by one word compared
// to the input.
f := a.Func("func " + name + "(z, x []Word, s uint) (c Word)")
// Check for no input early, since we need to start by reading 1 word.
n := f.Arg("z_len")
a.JmpZero(n, "ret0")
// Start loop by reading first input word.
s := f.ArgHint("s", HintShiftCount)
p := f.Pipe()
if name == "lshVU" {
p.SetBackward()
}
unroll := []int{1, 4}
if a.Arch == Arch386 {
unroll = []int{1} // too few registers for more
p.SetUseIndexCounter()
}
p.LoadPtrs(n)
a.Comment("shift first word into carry")
prev := p.LoadN(1)[0][0]
// Decide how to shift. On systems with a wide shift (x86), use that.
// Otherwise, we need shift by s and negative (reverse) shift by 64-s or 32-s.
shift := a.Lsh
shiftWide := a.LshWide
negShift := a.Rsh
negShiftReg := a.RshReg
if name == "rshVU" {
shift = a.Rsh
shiftWide = a.RshWide
negShift = a.Lsh
negShiftReg = a.LshReg
}
if a.Arch.HasShiftWide() {
// Use wide shift to avoid needing negative shifts.
// The invariant is that prev holds the previous word (not shifted at all),
// to be used as input into the wide shift.
// After the loop finishes, prev holds the final output word to be written.
c := a.Reg()
shiftWide(s, prev, a.Imm(0), c)
f.StoreArg(c, "c")
a.Free(c)
a.Comment("shift remaining words")
p.Start(n, unroll...)
p.Loop(func(in [][]Reg, out [][]Reg) {
// We reuse the input registers as output, delayed one cycle; prev is the first output.
// After writing the outputs to memory, we can copy the final x value into prev
// for the next iteration.
old := prev
for i, x := range in[0] {
shiftWide(s, x, old, old)
out[0][i] = old
old = x
}
p.StoreN(out)
a.Mov(old, prev)
})
a.Comment("store final shifted bits")
shift(s, prev, prev)
} else {
// Construct values from x << s and x >> (64-s).
// After the first word has been processed, the invariant is that
// prev holds x << s, to be used as the high bits of the next output word,
// once we find the low bits after reading the next input word.
// After the loop finishes, prev holds the final output word to be written.
sNeg := a.Reg()
a.Mov(a.Imm(a.Arch.WordBits), sNeg)
a.Sub(s, sNeg, sNeg, SmashCarry)
c := a.Reg()
negShift(sNeg, prev, c)
shift(s, prev, prev)
f.StoreArg(c, "c")
a.Free(c)
a.Comment("shift remaining words")
p.Start(n, unroll...)
p.Loop(func(in, out [][]Reg) {
if a.HasRegShift() {
// ARM (32-bit) allows shifts in most arithmetic expressions,
// including OR, letting us combine the negShift and a.Or.
// The simplest way to manage the registers is to do StoreN for
// one output at a time, and since we don't use multi-register
// stores on ARM, that doesn't hurt us.
out[0] = out[0][:1]
for _, x := range in[0] {
a.Or(negShiftReg(sNeg, x), prev, prev)
out[0][0] = prev
p.StoreN(out)
shift(s, x, prev)
}
return
}
// We reuse the input registers as output, delayed one cycle; z0 is the first output.
z0 := a.Reg()
z := z0
for i, x := range in[0] {
negShift(sNeg, x, z)
a.Or(prev, z, z)
shift(s, x, prev)
out[0][i] = z
z = x
}
p.StoreN(out)
})
a.Comment("store final shifted bits")
}
p.StoreN([][]Reg{{prev}})
p.Done()
a.Free(s)
a.Ret()
// Return 0, used from above.
a.Label("ret0")
f.StoreArg(a.Imm(0), "c")
a.Ret()
}
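// For reference, a pure-Go sketch of the lshVU semantics the generated
// assembly implements, for 0 < s < _W, writing _W for the word size in bits
// (lshVURef is a hypothetical name; Word is assumed unsigned and word-sized).
// The backward traversal matches the aliasing discussion at the top of
// shiftVU; rshVU mirrors it, walking forward and shifting in bits from the
// next word up.
//
//	func lshVURef(z, x []Word, s uint) (c Word) {
//		if len(z) == 0 {
//			return 0
//		}
//		ŝ := _W - s
//		c = x[len(x)-1] >> ŝ // bits shifted out the top become the carry
//		for i := len(z) - 1; i > 0; i-- {
//			z[i] = x[i]<<s | x[i-1]>>ŝ
//		}
//		z[0] = x[0] << s
//		return
//	}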