From a9c9cc07ac0d3dc73865a57e6ce45c22ada3b5c9 Mon Sep 17 00:00:00 2001
From: Russ Cox <rsc@golang.org>
Date: Mon, 20 Nov 2023 11:22:48 +1100
Subject: [PATCH] iter, runtime: add coroutine support

The exported API is only available with GOEXPERIMENT=rangefunc.
This will let Go 1.22 users who want to experiment with rangefuncs
access an efficient implementation of iter.Pull and iter.Pull2.

For #61897.

Change-Id: I6ef5fa8f117567efe4029b7b8b0f4d9b85697fb7
Reviewed-on: https://go-review.googlesource.com/c/go/+/543319
Reviewed-by: Michael Knyszek <mknyszek@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
---
 src/cmd/dist/test.go              |  11 ++
 src/cmd/internal/objabi/funcid.go |   1 +
 src/go/build/deps_test.go         |   3 +
 src/internal/abi/symtab.go        |   1 +
 src/iter/iter.go                  | 161 +++++++++++++++++++++++++++++
 src/iter/pull_test.go             | 118 +++++++++++++++++++++
 src/runtime/coro.go               | 165 ++++++++++++++++++++++++++++++
 src/runtime/proc.go               |   8 +-
 src/runtime/runtime2.go           |   5 +
 src/runtime/sizeof_test.go        |   6 +-
 src/runtime/traceback.go          |   2 +-
 11 files changed, 475 insertions(+), 6 deletions(-)
 create mode 100644 src/iter/iter.go
 create mode 100644 src/iter/pull_test.go
 create mode 100644 src/runtime/coro.go

diff --git a/src/cmd/dist/test.go b/src/cmd/dist/test.go
index 4450129e087..5e62bbf4c22 100644
--- a/src/cmd/dist/test.go
+++ b/src/cmd/dist/test.go
@@ -719,6 +719,17 @@ func (t *tester) registerTests() {
 			})
 	}
 
+	// GOEXPERIMENT=rangefunc tests
+	if !t.compileOnly {
+		t.registerTest("GOEXPERIMENT=rangefunc go test iter",
+			&goTest{
+				variant: "iter",
+				short:   t.short,
+				env:     []string{"GOEXPERIMENT=rangefunc"},
+				pkg:     "iter",
+			})
+	}
+
 	// GODEBUG=gcstoptheworld=2 tests. We only run these in long-test
 	// mode (with GO_TEST_SHORT=0) because this is just testing a
 	// non-critical debug setting.
diff --git a/src/cmd/internal/objabi/funcid.go b/src/cmd/internal/objabi/funcid.go
index 007107e778f..d9b47f1ec91 100644
--- a/src/cmd/internal/objabi/funcid.go
+++ b/src/cmd/internal/objabi/funcid.go
@@ -14,6 +14,7 @@ var funcIDs = map[string]abi.FuncID{
 	"asmcgocall":         abi.FuncID_asmcgocall,
 	"asyncPreempt":       abi.FuncID_asyncPreempt,
 	"cgocallback":        abi.FuncID_cgocallback,
+	"corostart":          abi.FuncID_corostart,
 	"debugCallV2":        abi.FuncID_debugCallV2,
 	"gcBgMarkWorker":     abi.FuncID_gcBgMarkWorker,
 	"rt0_go":             abi.FuncID_rt0_go,
diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go
index 1b93e78c706..7ce8d346b40 100644
--- a/src/go/build/deps_test.go
+++ b/src/go/build/deps_test.go
@@ -83,6 +83,9 @@ var depsRules = `
 	< internal/oserror, math/bits
 	< RUNTIME;
 
+	internal/race
+	< iter;
+
 	# slices depends on unsafe for overlapping check, cmp for comparison
 	# semantics, and math/bits for # calculating bitlength of numbers.
 	unsafe, cmp, math/bits
diff --git a/src/internal/abi/symtab.go b/src/internal/abi/symtab.go
index bf6ea82f1ce..ce1b6501551 100644
--- a/src/internal/abi/symtab.go
+++ b/src/internal/abi/symtab.go
@@ -44,6 +44,7 @@ const (
 	FuncID_asmcgocall
 	FuncID_asyncPreempt
 	FuncID_cgocallback
+	FuncID_corostart
 	FuncID_debugCallV2
 	FuncID_gcBgMarkWorker
 	FuncID_goexit
diff --git a/src/iter/iter.go b/src/iter/iter.go
new file mode 100644
index 00000000000..240df00f7f4
--- /dev/null
+++ b/src/iter/iter.go
@@ -0,0 +1,161 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build goexperiment.rangefunc
+
+// Package iter provides basic definitions and operations
+// related to iteration in Go.
+//
+// This package is experimental and can only be imported
+// when building with GOEXPERIMENT=rangefunc.
+package iter
+
+import (
+	"internal/race"
+	"unsafe"
+	_ "unsafe"
+) // for linkname
+
+// Seq is an iterator over sequences of individual values.
+// When called as seq(yield), seq calls yield(v) for each value v in the sequence,
+// stopping early if yield returns false.
+type Seq[V any] func(yield func(V) bool)
+
+// Seq2 is an iterator over sequences of pairs of values, most commonly key-value pairs.
+// When called as seq(yield), seq calls yield(k, v) for each pair (k, v) in the sequence,
+// stopping early if yield returns false.
+type Seq2[K, V any] func(yield func(K, V) bool)
+
+type coro struct{}
+
+//go:linkname newcoro runtime.newcoro
+func newcoro(func(*coro)) *coro
+
+//go:linkname coroswitch runtime.coroswitch
+func coroswitch(*coro)
+
+// Pull converts the “push-style” iterator sequence seq
+// into a “pull-style” iterator accessed by the two functions
+// next and stop.
+//
+// Next returns the next value in the sequence
+// and a boolean indicating whether the value is valid.
+// When the sequence is over, next returns the zero V and false.
+// It is valid to call next after reaching the end of the sequence
+// or after calling stop. These calls will continue
+// to return the zero V and false.
+//
+// Stop ends the iteration. It must be called when the caller is
+// no longer interested in next values and next has not yet
+// signaled that the sequence is over (with a false boolean return).
+// It is valid to call stop multiple times and when next has
+// already returned false.
+//
+// It is an error to call next or stop from multiple goroutines
+// simultaneously.
+func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
+	var (
+		v     V
+		ok    bool
+		done  bool
+		racer int
+	)
+	c := newcoro(func(c *coro) {
+		race.Acquire(unsafe.Pointer(&racer))
+		yield := func(v1 V) bool {
+			if done {
+				return false
+			}
+			v, ok = v1, true
+			race.Release(unsafe.Pointer(&racer))
+			coroswitch(c)
+			race.Acquire(unsafe.Pointer(&racer))
+			return !done
+		}
+		seq(yield)
+		var v0 V
+		v, ok = v0, false
+		done = true
+		race.Release(unsafe.Pointer(&racer))
+	})
+	next = func() (v1 V, ok1 bool) {
+		race.Write(unsafe.Pointer(&racer)) // detect races
+		if done {
+			return
+		}
+		race.Release(unsafe.Pointer(&racer))
+		coroswitch(c)
+		race.Acquire(unsafe.Pointer(&racer))
+		return v, ok
+	}
+	stop = func() {
+		race.Write(unsafe.Pointer(&racer)) // detect races
+		if !done {
+			done = true
+			race.Release(unsafe.Pointer(&racer))
+			coroswitch(c)
+			race.Acquire(unsafe.Pointer(&racer))
+		}
+	}
+	return next, stop
+}
+
+// Pull2 converts the “push-style” iterator sequence seq
+// into a “pull-style” iterator accessed by the two functions
+// next and stop.
+//
+// Next returns the next pair in the sequence
+// and a boolean indicating whether the pair is valid.
+// When the sequence is over, next returns a pair of zero values and false.
+// It is valid to call next after reaching the end of the sequence
+// or after calling stop. These calls will continue
+// to return a pair of zero values and false.
+//
+// Stop ends the iteration. It must be called when the caller is
+// no longer interested in next values and next has not yet
+// signaled that the sequence is over (with a false boolean return).
+// It is valid to call stop multiple times and when next has
+// already returned false.
+//
+// It is an error to call next or stop from multiple goroutines
+// simultaneously.
+func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
+	var (
+		k    K
+		v    V
+		ok   bool
+		done bool
+	)
+	c := newcoro(func(c *coro) {
+		yield := func(k1 K, v1 V) bool {
+			if done {
+				return false
+			}
+			k, v, ok = k1, v1, true
+			coroswitch(c)
+			return !done
+		}
+		seq(yield)
+		var k0 K
+		var v0 V
+		k, v, ok = k0, v0, false
+		done = true
+	})
+	next = func() (k1 K, v1 V, ok1 bool) {
+		race.Write(unsafe.Pointer(&c)) // detect races
+		if done {
+			return
+		}
+		coroswitch(c)
+		return k, v, ok
+	}
+	stop = func() {
+		race.Write(unsafe.Pointer(&c)) // detect races
+		if !done {
+			done = true
+			coroswitch(c)
+		}
+	}
+	return next, stop
+}
diff --git a/src/iter/pull_test.go b/src/iter/pull_test.go
new file mode 100644
index 00000000000..38e0ee993a1
--- /dev/null
+++ b/src/iter/pull_test.go
@@ -0,0 +1,118 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build goexperiment.rangefunc
+
+package iter
+
+import (
+	"fmt"
+	"runtime"
+	"testing"
+)
+
+func count(n int) Seq[int] {
+	return func(yield func(int) bool) {
+		for i := range n {
+			if !yield(i) {
+				break
+			}
+		}
+	}
+}
+
+func squares(n int) Seq2[int, int64] {
+	return func(yield func(int, int64) bool) {
+		for i := range n {
+			if !yield(i, int64(i)*int64(i)) {
+				break
+			}
+		}
+	}
+}
+
+func TestPull(t *testing.T) {
+
+	for end := 0; end <= 3; end++ {
+		t.Run(fmt.Sprint(end), func(t *testing.T) {
+			ng := runtime.NumGoroutine()
+			wantNG := func(want int) {
+				if xg := runtime.NumGoroutine() - ng; xg != want {
+					t.Helper()
+					t.Errorf("have %d extra goroutines, want %d", xg, want)
+				}
+			}
+			wantNG(0)
+			next, stop := Pull(count(3))
+			wantNG(1)
+			for i := range end {
+				v, ok := next()
+				if v != i || ok != true {
+					t.Fatalf("next() = %d, %v, want %d, %v", v, ok, i, true)
+				}
+				wantNG(1)
+			}
+			wantNG(1)
+			if end < 3 {
+				stop()
+				wantNG(0)
+			}
+			for range 2 {
+				v, ok := next()
+				if v != 0 || ok != false {
+					t.Fatalf("next() = %d, %v, want %d, %v", v, ok, 0, false)
+				}
+				wantNG(0)
+			}
+			wantNG(0)
+
+			stop()
+			stop()
+			stop()
+			wantNG(0)
+		})
+	}
+}
+
+func TestPull2(t *testing.T) {
+	for end := 0; end <= 3; end++ {
+		t.Run(fmt.Sprint(end), func(t *testing.T) {
+			ng := runtime.NumGoroutine()
+			wantNG := func(want int) {
+				if xg := runtime.NumGoroutine() - ng; xg != want {
+					t.Helper()
+					t.Errorf("have %d extra goroutines, want %d", xg, want)
+				}
+			}
+			wantNG(0)
+			next, stop := Pull2(squares(3))
+			wantNG(1)
+			for i := range end {
+				k, v, ok := next()
+				if k != i || v != int64(i*i) || ok != true {
+					t.Fatalf("next() = %d, %d, %v, want %d, %d, %v", k, v, ok, i, i*i, true)
+				}
+				wantNG(1)
+			}
+			wantNG(1)
+			if end < 3 {
+				stop()
+				wantNG(0)
+			}
+			for range 2 {
+				k, v, ok := next()
+				if v != 0 || ok != false {
+					t.Fatalf("next() = %d, %d, %v, want %d, %d, %v", k, v, ok, 0, 0, false)
+				}
+				wantNG(0)
+			}
+			wantNG(0)
+
+			stop()
+			stop()
+			stop()
+			wantNG(0)
+		})
+	}
+}
diff --git a/src/runtime/coro.go b/src/runtime/coro.go
new file mode 100644
index 00000000000..0d6666e343b
--- /dev/null
+++ b/src/runtime/coro.go
@@ -0,0 +1,165 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package runtime
+
+import "unsafe"
+
+// A coro represents extra concurrency without extra parallelism,
+// as would be needed for a coroutine implementation.
+// The coro does not represent a specific coroutine, only the ability
+// to do coroutine-style control transfers.
+// It can be thought of as like a special channel that always has
+// a goroutine blocked on it. If another goroutine calls coroswitch(c),
+// the caller becomes the goroutine blocked in c, and the goroutine
+// formerly blocked in c starts running.
+// These switches continue until a call to coroexit(c),
+// which ends the use of the coro by releasing the blocked
+// goroutine in c and exiting the current goroutine.
+//
+// Coros are heap allocated and garbage collected, so that user code
+// can hold a pointer to a coro without causing potential dangling
+// pointer errors.
+type coro struct {
+	gp guintptr
+	f  func(*coro)
+}
+
+//go:linkname newcoro
+
+// newcoro creates a new coro containing a
+// goroutine blocked waiting to run f
+// and returns that coro.
+func newcoro(f func(*coro)) *coro {
+	c := new(coro)
+	c.f = f
+	pc := getcallerpc()
+	gp := getg()
+	systemstack(func() {
+		start := corostart
+		startfv := *(**funcval)(unsafe.Pointer(&start))
+		gp = newproc1(startfv, gp, pc)
+	})
+	gp.coroarg = c
+	gp.waitreason = waitReasonCoroutine
+	casgstatus(gp, _Grunnable, _Gwaiting)
+	c.gp.set(gp)
+	return c
+}
+
+//go:linkname corostart
+
+// corostart is the entry func for a new coroutine.
+// It runs the coroutine user function f passed to corostart
+// and then calls coroexit to remove the extra concurrency.
+func corostart() {
+	gp := getg()
+	c := gp.coroarg
+	gp.coroarg = nil
+
+	c.f(c)
+	coroexit(c)
+}
+
+// coroexit is like coroswitch but closes the coro
+// and exits the current goroutine
+func coroexit(c *coro) {
+	gp := getg()
+	gp.coroarg = c
+	gp.coroexit = true
+	mcall(coroswitch_m)
+}
+
+//go:linkname coroswitch
+
+// coroswitch switches to the goroutine blocked on c
+// and then blocks the current goroutine on c.
+func coroswitch(c *coro) {
+	gp := getg()
+	gp.coroarg = c
+	mcall(coroswitch_m)
+}
+
+// coroswitch_m is the implementation of coroswitch
+// that runs on the m stack.
+//
+// Note: Coroutine switches are expected to happen at
+// an order of magnitude (or more) higher frequency
+// than regular goroutine switches, so this path is heavily
+// optimized to remove unnecessary work.
+// The fast path here is three CAS: the one at the top on gp.atomicstatus,
+// the one in the middle to choose the next g,
+// and the one at the bottom on gnext.atomicstatus.
+// It is important not to add more atomic operations or other
+// expensive operations to the fast path.
+func coroswitch_m(gp *g) {
+	// TODO(rsc,mknyszek): add tracing support in a lightweight manner.
+	// Probably the tracer will need a global bool (set and cleared during STW)
+	// that this code can check to decide whether to use trace.gen.Load();
+	// we do not want to do the atomic load all the time, especially when
+	// tracer use is relatively rare.
+	c := gp.coroarg
+	gp.coroarg = nil
+	exit := gp.coroexit
+	gp.coroexit = false
+	mp := gp.m
+
+	if exit {
+		gdestroy(gp)
+		gp = nil
+	} else {
+		// If we can CAS ourselves directly from running to waiting, so do,
+		// keeping the control transfer as lightweight as possible.
+		gp.waitreason = waitReasonCoroutine
+		if !gp.atomicstatus.CompareAndSwap(_Grunning, _Gwaiting) {
+			// The CAS failed: use casgstatus, which will take care of
+			// coordinating with the garbage collector about the state change.
+			casgstatus(gp, _Grunning, _Gwaiting)
+		}
+
+		// Clear gp.m.
+		setMNoWB(&gp.m, nil)
+	}
+
+	// The goroutine stored in c is the one to run next.
+	// Swap it with ourselves.
+	var gnext *g
+	for {
+		// Note: this is a racy load, but it will eventually
+		// get the right value, and if it gets the wrong value,
+		// the c.gp.cas will fail, so no harm done other than
+		// a wasted loop iteration.
+		// The cas will also sync c.gp's
+		// memory enough that the next iteration of the racy load
+		// should see the correct value.
+		// We are avoiding the atomic load to keep this path
+		// as lightweight as absolutely possible.
+		// (The atomic load is free on x86 but not free elsewhere.)
+		next := c.gp
+		if next.ptr() == nil {
+			throw("coroswitch on exited coro")
+		}
+		var self guintptr
+		self.set(gp)
+		if c.gp.cas(next, self) {
+			gnext = next.ptr()
+			break
+		}
+	}
+
+	// Start running next, without heavy scheduling machinery.
+	// Set mp.curg and gnext.m and then update scheduling state
+	// directly if possible.
+	setGNoWB(&mp.curg, gnext)
+	setMNoWB(&gnext.m, mp)
+	if !gnext.atomicstatus.CompareAndSwap(_Gwaiting, _Grunning) {
+		// The CAS failed: use casgstatus, which will take care of
+		// coordinating with the garbage collector about the state change.
+		casgstatus(gnext, _Gwaiting, _Grunnable)
+		casgstatus(gnext, _Grunnable, _Grunning)
+	}
+
+	// Switch to gnext. Does not return.
+	gogo(&gnext.sched)
+}
diff --git a/src/runtime/proc.go b/src/runtime/proc.go
index 661dc0f1caa..aae30dc2a8f 100644
--- a/src/runtime/proc.go
+++ b/src/runtime/proc.go
@@ -4175,6 +4175,11 @@ func goexit1() {
 
 // goexit continuation on g0.
 func goexit0(gp *g) {
+	gdestroy(gp)
+	schedule()
+}
+
+func gdestroy(gp *g) {
 	mp := getg().m
 	pp := mp.p.ptr()
 
@@ -4211,7 +4216,7 @@ func goexit0(gp *g) {
 
 	if GOARCH == "wasm" { // no threads yet on wasm
 		gfput(pp, gp)
-		schedule() // never returns
+		return
 	}
 
 	if mp.lockedInt != 0 {
@@ -4234,7 +4239,6 @@ func goexit0(gp *g) {
 			mp.lockedExt = 0
 		}
 	}
-	schedule()
 }
 
 // save updates getg().sched to refer to pc and sp so that a following
diff --git a/src/runtime/runtime2.go b/src/runtime/runtime2.go
index e64be992b05..2d3fd30e633 100644
--- a/src/runtime/runtime2.go
+++ b/src/runtime/runtime2.go
@@ -483,6 +483,7 @@ type g struct {
 	// inMarkAssist indicates whether the goroutine is in mark assist.
 	// Used by the execution tracer.
 	inMarkAssist bool
+	coroexit     bool // argument to coroswitch_m
 
 	raceignore    int8  // ignore race detection events
 	nocgocallback bool  // whether disable callback from C
@@ -507,6 +508,8 @@ type g struct {
 	timer         *timer         // cached timer for time.Sleep
 	selectDone    atomic.Uint32  // are we participating in a select and did someone win the race?
 
+	coroarg *coro // argument during coroutine transfers
+
 	// goroutineProfiled indicates the status of this goroutine's stack for the
 	// current in-progress goroutine profile
 	goroutineProfiled goroutineProfileStateHolder
@@ -1124,6 +1127,7 @@ const (
 	waitReasonFlushProcCaches                         // "flushing proc caches"
 	waitReasonTraceGoroutineStatus                    // "trace goroutine status"
 	waitReasonTraceProcStatus                         // "trace proc status"
+	waitReasonCoroutine                               // "coroutine"
 )
 
 var waitReasonStrings = [...]string{
@@ -1162,6 +1166,7 @@ var waitReasonStrings = [...]string{
 	waitReasonFlushProcCaches:       "flushing proc caches",
 	waitReasonTraceGoroutineStatus:  "trace goroutine status",
 	waitReasonTraceProcStatus:       "trace proc status",
+	waitReasonCoroutine:             "coroutine",
 }
 
 func (w waitReason) String() string {
diff --git a/src/runtime/sizeof_test.go b/src/runtime/sizeof_test.go
index ccc0864ca9d..aa8caaaddae 100644
--- a/src/runtime/sizeof_test.go
+++ b/src/runtime/sizeof_test.go
@@ -17,9 +17,9 @@ import (
 func TestSizeof(t *testing.T) {
 	const _64bit = unsafe.Sizeof(uintptr(0)) == 8
 
-	g32bit := uintptr(252)
+	g32bit := uintptr(256)
 	if goexperiment.ExecTracer2 {
-		g32bit = uintptr(256)
+		g32bit = uintptr(260)
 	}
 
 	var tests = []struct {
@@ -27,7 +27,7 @@ func TestSizeof(t *testing.T) {
 		_32bit uintptr // size on 32bit platforms
 		_64bit uintptr // size on 64bit platforms
 	}{
-		{runtime.G{}, g32bit, 408}, // g, but exported for testing
+		{runtime.G{}, g32bit, 424}, // g, but exported for testing
 		{runtime.Sudog{}, 56, 88},  // sudog, but exported for testing
 	}
 
diff --git a/src/runtime/traceback.go b/src/runtime/traceback.go
index 66a1cc85ee4..1e5afc6bdde 100644
--- a/src/runtime/traceback.go
+++ b/src/runtime/traceback.go
@@ -1322,7 +1322,7 @@ func isSystemGoroutine(gp *g, fixed bool) bool {
 	if !f.valid() {
 		return false
 	}
-	if f.funcID == abi.FuncID_runtime_main || f.funcID == abi.FuncID_handleAsyncEvent {
+	if f.funcID == abi.FuncID_runtime_main || f.funcID == abi.FuncID_corostart || f.funcID == abi.FuncID_handleAsyncEvent {
 		return false
 	}
 	if f.funcID == abi.FuncID_runfinq {
-- 
GitLab